RegRandomForests.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. /**
  2. * @file RegRandomForests.h
  3. * @brief implementation of random set forest for regression
  4. * @author Sven Sickert
  5. * @date 06/19/2013
  6. */
  7. #ifndef REGRANDOMFORESTSINCLUDE
  8. #define REGRANDOMFORESTSINCLUDE
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "vislearning/regression/regressionbase/RegressionAlgorithm.h"
#include "vislearning/regression/randomforest/RegressionTree.h"
#include "vislearning/regression/randomforest/RegressionTreeBuilder.h"
  15. namespace OBJREC
  16. {
  17. /** implementation of random set forests for regression */
  18. class RegRandomForests : public RegressionAlgorithm
  19. {
  20. protected:
  21. /** vector containing all decision trees for regression */
  22. std::vector<RegressionTree *> forest;
  23. /** number of trees which will be generated during training */
  24. int number_of_trees;
  25. /** fraction of features used for each tree */
  26. double features_per_tree;
  27. /** fraction of training examples used for each tree */
  28. double samples_per_tree;
  29. /** if >0 then prune the trees using pruneTreeLeastSquares */
  30. double minimum_error_reduction;
  31. /** stored config to initialize a tree */
  32. const NICE::Config *conf;
  33. /** config section containing important config values */
  34. std::string confsection;
  35. /** pointer to the tree builder method */
  36. RegressionTreeBuilder *builder;
  37. /** calculate out-of-bag statistics or not */
  38. bool enableOutOfBagEstimates;
  39. /** out-of-bag statistics */
  40. std::vector<std::pair<double, double> > oobResults;
  41. /** predict using only a subset of all trees */
  42. double predict ( const NICE::Vector & x,
  43. const std::vector<int> & outofbagtrees );
  44. /** calculate out-of-bag statistics */
  45. void calcOutOfBagEstimates ( std::vector< std::vector<int> > & outofbagtrees,
  46. NICE::VVector x,
  47. NICE::Vector y );
  48. /** save example selection per tree */
  49. std::vector<std::vector<int> > exselection;
  50. public:
  51. /** initialize the regression method */
  52. RegRandomForests ( const NICE::Config *conf,
  53. std::string section );
  54. /** do nothing */
  55. RegRandomForests ();
  56. /** simple destructor */
  57. virtual ~RegRandomForests();
  58. /** learn parameters/models/whatever using a set of vectors and
  59. * their corresponding function values
  60. */
  61. void teach ( const NICE::VVector & x, const NICE::Vector & y );
  62. /** main prediction function */
  63. double predict ( const NICE::Vector & x );
  64. /** get all leaf nodes for a given value (or inner nodes if depth is set to the level) */
  65. void getLeafNodes ( NICE::Vector x,
  66. std::vector<RegressionNode *> & leafNodes,
  67. int depth = 100000 );
  68. /** get all leaf nodes (or inner nodes if depth is set to the level) */
  69. void getAllLeafNodes ( std::vector<RegressionNode *> & leafNodes );
  70. /** enumerate all nodes within the trees */
  71. void indexDescendants ( std::map<RegressionNode *, std::pair<long, int> > & index ) const;
  72. /** reset all counters in all nodes contained in the forest */
  73. void resetCounters ();
  74. /** clone function */
  75. virtual RegRandomForests *clone ( void ) const
  76. {
  77. fthrow ( NICE::Exception, "clone() not yet implemented!\n" );
  78. }
  79. /** get out of bag estimates */
  80. std::vector<std::pair<double, double> > & getOutOfBagResults ()
  81. {
  82. return oobResults;
  83. };
  84. /** set the number of trees */
  85. void setComplexity ( int size );
  86. /** IO functions */
  87. void restore ( std::istream & is, int format = 0 );
  88. void store ( std::ostream & os, int format = 0 ) const;
  89. void clear ();
  90. };
  91. } // namespace
  92. #endif