DecisionNode.cpp 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. /**
  2. * @file DecisionNode.cpp
  3. * @brief decision node
  4. * @author Erik Rodner
  5. * @date 04/24/2008
  6. */
  7. #include <iostream>
  8. #include "vislearning/classifier/fpclassifier/randomforest/DecisionNode.h"
  9. using namespace OBJREC;
  10. using namespace std;
  11. using namespace NICE;
  12. DecisionNode::~DecisionNode()
  13. {
  14. if ( f != NULL )
  15. delete f;
  16. }
  17. DecisionNode::DecisionNode ()
  18. {
  19. left = NULL;
  20. right = NULL;
  21. f = NULL;
  22. counter = 0;
  23. }
  24. DecisionNode *DecisionNode::getLeafNode (
  25. const Example & ce,
  26. int depth )
  27. {
  28. counter += ce.weight;
  29. if ( (!depth) || ((left == NULL) && (right == NULL)) )
  30. return this;
  31. double val = f->val( &ce );
  32. if ( val < threshold )
  33. if ( left != NULL )
  34. return left->getLeafNode ( ce, depth - 1 );
  35. else
  36. return this;
  37. else
  38. if ( right != NULL )
  39. return right->getLeafNode ( ce, depth - 1 );
  40. else
  41. return this;
  42. }
  43. void DecisionNode::traverse (
  44. const Example & ce,
  45. FullVector & _distribution )
  46. {
  47. DecisionNode *leaf = getLeafNode ( ce );
  48. _distribution = leaf->distribution;
  49. }
  50. void DecisionNode::statistics ( int & depth, int & count ) const
  51. {
  52. int dl, cl;
  53. if ( left != NULL )
  54. {
  55. left->statistics ( dl, cl );
  56. dl++;
  57. } else {
  58. dl = 0;
  59. cl = 0;
  60. }
  61. if ( right != NULL )
  62. {
  63. right->statistics ( depth, count );
  64. depth++;
  65. } else {
  66. depth = 0;
  67. count = 0;
  68. }
  69. depth = (depth > dl) ? depth : dl;
  70. count += cl + 1;
  71. }
  72. void DecisionNode::indexDescendants ( map<DecisionNode *, pair<long, int> > & index, long & maxindex, int depth ) const
  73. {
  74. if ( left != NULL )
  75. {
  76. maxindex++;
  77. index.insert ( pair<DecisionNode *, pair<long, int> > ( left, pair<long, int>( maxindex, depth + 1 ) ) );
  78. left->indexDescendants ( index, maxindex, depth+1 );
  79. }
  80. if ( right != NULL )
  81. {
  82. maxindex++;
  83. index.insert ( pair<DecisionNode *, pair<long, int> > ( right, pair<long, int>( maxindex, depth + 1 ) ) );
  84. right->indexDescendants ( index, maxindex, depth+1 );
  85. }
  86. }
  87. void DecisionNode::resetCounters ()
  88. {
  89. counter = 0;
  90. if ( left != NULL ) left->resetCounters();
  91. if ( right != NULL ) right->resetCounters();
  92. }
  93. void DecisionNode::copy ( DecisionNode *node )
  94. {
  95. left = node->left;
  96. right = node->right;
  97. threshold = node->threshold;
  98. f = node->f;
  99. distribution = node->distribution;
  100. trainExamplesIndices = node->trainExamplesIndices;
  101. }
  102. bool DecisionNode::isLeaf () const
  103. {
  104. return ( (right == NULL) && (left == NULL) );
  105. }