RegressionTree.h 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /**
  2. * @file RegressionTree.h
  3. * @brief decision tree implementation for regression (regression tree)
  4. * @author Sven Sickert
  5. * @date 06/19/2013
  6. */
  7. #ifndef REGRESSIONTREEINCLUDE
  8. #define REGRESSIONTREEINCLUDE
#include <iosfwd>
#include <map>
#include <set>
#include <utility>
#include <vector>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/basics/triplet.h"
#include "core/basics/Config.h"
#include "core/basics/Persistent.h"
#include "vislearning/regression/randomforest/RegressionNode.h"
  17. namespace OBJREC {
  18. /** decision tree implementation for regression */
  19. class RegressionTree : public NICE::Persistent
  20. {
  21. protected:
  22. RegressionNode *root;
  23. const NICE::Config *conf; // for restore operation
  24. public:
  25. static void deleteNodes ( RegressionNode *tree );
  26. static RegressionNode *pruneTreeLeastSquares (
  27. RegressionNode *node,
  28. double minErrorReduction,
  29. double & lsError );
  30. /** simple consructor */
  31. RegressionTree( const NICE::Config *conf );
  32. /** simple destructor */
  33. virtual ~RegressionTree();
  34. void traverse ( const NICE::Vector & x,
  35. double & predVal );
  36. void resetCounters ();
  37. void statistics( int & depth, int & count ) const;
  38. void indexDescendants ( std::map<RegressionNode *, std::pair<long, int> > & index,
  39. long & maxindex ) const;
  40. RegressionNode *getLeafNode ( NICE::Vector & x,
  41. int maxdepth = 100000 );
  42. void getLeaves ( RegressionNode *node, std::vector<RegressionNode*> &leaves);
  43. std::vector<RegressionNode *> getAllLeafNodes ();
  44. RegressionNode *getRoot( ) const { return root; };
  45. void pruneTreeLeastSquares ( double minErrorReduction );
  46. void setRoot( RegressionNode *newroot );
  47. void restore (std::istream & is, int format = 0);
  48. void store (std::ostream & os, int format = 0) const;
  49. void clear ();
  50. };
  51. } // namespace
  52. #endif