FPCRandomForests.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * @file FPCRandomForests.h
  3. * @brief implementation of random set forests
  4. * @author Erik Rodner
  5. * @date 04/24/2008
  6. */
  7. #ifndef FPCRANDOMFORESTSINCLUDE
  8. #define FPCRANDOMFORESTSINCLUDE
#include <iosfwd>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/image/ImageT.h"
#include "core/imagedisplay/ImageDisplay.h"
#include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
#include "vislearning/cbaselib/FeaturePool.h"
#include "DecisionTree.h"
#include "DecisionTreeBuilder.h"
  18. namespace OBJREC
  19. {
  20. /** implementation of random set forests */
  21. class FPCRandomForests : public FeaturePoolClassifier
  22. {
  23. protected:
  24. /** vector containing all decision trees */
  25. std::vector<DecisionTree *> forest;
  26. /** number of trees which will be generated in the
  27. during training */
  28. int number_of_trees;
  29. /** fraction of features used for each tree */
  30. double features_per_tree;
  31. /** fraction of training examples used for each tree */
  32. double samples_per_tree;
  33. /** use an equal number of training examples of each class
  34. to build a single tree */
  35. bool use_simple_balancing;
  36. /** weight examples according to a priori class probabilities
  37. as estimated using the distribution contained in the training data */
  38. bool weight_examples;
  39. /** if >0 then prune the trees using pruneTreeEntropy */
  40. double minimum_entropy;
  41. /** clear all examples after building a tree, this deletes
  42. all cached images contained in CachedExample etc. */
  43. bool memory_efficient;
  44. /** stored config to initialize a tree */
  45. const NICE::Config *conf;
  46. /** config section containing important config values */
  47. std::string confsection;
  48. /** pointer to the tree builder method */
  49. DecisionTreeBuilder *builder;
  50. /** out-of-bag statistics */
  51. bool enableOutOfBagEstimates;
  52. std::vector<std::pair<double, int> > oobResults;
  53. /** classify using only a subset of all trees */
  54. ClassificationResult classify ( Example & pce,
  55. const std::vector<int> & outofbagtrees );
  56. /** calculate out-of-bag statistics */
  57. void calcOutOfBagEstimates ( std::vector< std::vector<int> > & outofbagtrees,
  58. Examples & examples );
  59. /** save example selection per tree */
  60. std::vector<std::vector<int> > exselection;
  61. public:
  62. /** initialize the classifier */
  63. FPCRandomForests ( const NICE::Config *conf,
  64. std::string section );
  65. /** do nothing */
  66. FPCRandomForests ();
  67. /** simple destructor */
  68. virtual ~FPCRandomForests();
  69. /** main classification function */
  70. ClassificationResult classify ( Example & pce );
  71. int classify_optimize ( Example & pce );
  72. /** get all leaf nodes for an given example (or inner nodes if depth is set to the level) */
  73. void getLeafNodes ( Example & pce,
  74. std::vector<DecisionNode *> & leafNodes,
  75. int depth = 100000 );
  76. /** get all leaf nodes (or inner nodes if depth is set to the level) */
  77. void getAllLeafNodes ( std::vector<DecisionNode *> & leafNodes );
  78. /** perform training using a given feature pool and some training data */
  79. virtual void train ( FeaturePool & fp,
  80. Examples & examples );
  81. /** enumerate all nodes within the trees */
  82. void indexDescendants ( std::map<DecisionNode *, std::pair<long, int> > & index ) const;
  83. /** reset all counters in all nodes contained in the forest */
  84. void resetCounters ();
  85. /** direct access to all trees */
  86. const std::vector<DecisionTree *> & getForest () const
  87. {
  88. return forest;
  89. };
  90. /** direct write access to all trees */
  91. std::vector<DecisionTree *> & getForestNonConst ()
  92. {
  93. return forest;
  94. };
  95. /** clone this object */
  96. FeaturePoolClassifier *clone () const;
  97. /** get out of bag estimates */
  98. std::vector<std::pair<double, int> > & getOutOfBagResults ()
  99. {
  100. return oobResults;
  101. };
  102. /** set the number of trees */
  103. void setComplexity ( int size );
  104. /** IO functions */
  105. void restore ( std::istream & is, int format = 0 );
  106. void store ( std::ostream & os, int format = 0 ) const;
  107. void clear ();
  108. };
  109. } // namespace
  110. #endif