RegressionNode.h 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  /**
   * @file RegressionNode.h
   * @brief Declaration of RegressionNode, a split node of a regression tree
   * @author Sven Sickert
   * @date 06/19/2013
   */
  7. #ifndef REGRESSIONNODEINCLUDE
  8. #define REGRESSIONNODEINCLUDE
  #include <limits>
  #include <map>
  #include <utility>
  #include <vector>

  #include "core/vector/VectorT.h"
  #include "core/vector/MatrixT.h"
  13. namespace OBJREC {
  14. /** regression node: f(x) < threshold ? */
  15. class RegressionNode
  16. {
  17. protected:
  18. public:
  19. /** threshold of the regression node */
  20. double threshold;
  21. /** counter which can be used to
  22. count the number of examples which reached the node */
  23. double counter;
  24. /** the feature used for the regression node split */
  25. int f;
  26. /** the least squares error of the node */
  27. double lsError;
  28. /** the prediction value of the node */
  29. double predVal;
  30. /** the left branch of the tree */
  31. RegressionNode *left;
  32. /** the right branch of the tree */
  33. RegressionNode *right;
  34. /** Indices of examples which were used to estimate the
  35. * prediction value during training */
  36. std::vector<int> trainExamplesIndices;
  37. /** constructor */
  38. RegressionNode ();
  39. /** simple destructor */
  40. virtual ~RegressionNode();
  41. /** traverse the tree and get the resulting leaf node */
  42. RegressionNode *getLeafNode ( const NICE::Vector & x,
  43. int depth = std::numeric_limits<int>::max() );
  44. /** traverse this node with an example */
  45. void traverse ( const NICE::Vector & x,
  46. double & predVal );
  47. /** calculate the overall statistic of the current branch */
  48. void statistics ( int & depth, int & count ) const;
  49. /** only index descendants (with > depth), do not index node itsself */
  50. void indexDescendants ( std::map<RegressionNode *,
  51. std::pair<long, int> > & index,
  52. long & maxindex,
  53. int depth ) const;
  54. /** calculate the prediction value for this node */
  55. void nodePrediction( const NICE::Vector & y,
  56. const std::vector<int> & selection);
  57. /** reset the counters variable of the current branch */
  58. void resetCounters ();
  59. /** copy the node information to another node */
  60. void copy ( RegressionNode *node );
  61. /** is this node a leaf */
  62. bool isLeaf () const;
  63. };
  64. } // namespace
  65. #endif