DecisionNode.cpp 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * @file DecisionNode.cpp
  3. * @brief decision node
  4. * @author Erik Rodner
  5. * @date 04/24/2008
  6. */
  7. #include <iostream>
  8. #include "core/image/ImageT.h"
  9. #include "core/imagedisplay/ImageDisplay.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/DecisionNode.h"
  11. using namespace OBJREC;
  12. using namespace std;
  13. using namespace NICE;
  14. DecisionNode::~DecisionNode()
  15. {
  16. if ( f != NULL )
  17. delete f;
  18. }
  19. DecisionNode::DecisionNode ()
  20. {
  21. left = NULL;
  22. right = NULL;
  23. f = NULL;
  24. counter = 0;
  25. }
  26. DecisionNode *DecisionNode::getLeafNode (
  27. const Example & ce,
  28. int depth )
  29. {
  30. counter += ce.weight;
  31. if ( (!depth) || ((left == NULL) && (right == NULL)) )
  32. return this;
  33. double val = f->val( &ce );
  34. if ( val < threshold )
  35. if ( left != NULL )
  36. return left->getLeafNode ( ce, depth - 1 );
  37. else
  38. return this;
  39. else
  40. if ( right != NULL )
  41. return right->getLeafNode ( ce, depth - 1 );
  42. else
  43. return this;
  44. }
  45. void DecisionNode::traverse (
  46. const Example & ce,
  47. FullVector & _distribution )
  48. {
  49. DecisionNode *leaf = getLeafNode ( ce );
  50. _distribution = leaf->distribution;
  51. }
  52. void DecisionNode::statistics ( int & depth, int & count ) const
  53. {
  54. int dl, cl;
  55. if ( left != NULL )
  56. {
  57. left->statistics ( dl, cl );
  58. dl++;
  59. } else {
  60. dl = 0;
  61. cl = 0;
  62. }
  63. if ( right != NULL )
  64. {
  65. right->statistics ( depth, count );
  66. depth++;
  67. } else {
  68. depth = 0;
  69. count = 0;
  70. }
  71. depth = (depth > dl) ? depth : dl;
  72. count += cl + 1;
  73. }
  74. void DecisionNode::indexDescendants ( map<DecisionNode *, pair<long, int> > & index, long & maxindex, int depth ) const
  75. {
  76. if ( left != NULL )
  77. {
  78. maxindex++;
  79. index.insert ( pair<DecisionNode *, pair<long, int> > ( left, pair<long, int>( maxindex, depth + 1 ) ) );
  80. left->indexDescendants ( index, maxindex, depth+1 );
  81. }
  82. if ( right != NULL )
  83. {
  84. maxindex++;
  85. index.insert ( pair<DecisionNode *, pair<long, int> > ( right, pair<long, int>( maxindex, depth + 1 ) ) );
  86. right->indexDescendants ( index, maxindex, depth+1 );
  87. }
  88. }
  89. void DecisionNode::resetCounters ()
  90. {
  91. counter = 0;
  92. if ( left != NULL ) left->resetCounters();
  93. if ( right != NULL ) right->resetCounters();
  94. }
  95. void DecisionNode::copy ( DecisionNode *node )
  96. {
  97. left = node->left;
  98. right = node->right;
  99. threshold = node->threshold;
  100. f = node->f;
  101. distribution = node->distribution;
  102. trainExamplesIndices = node->trainExamplesIndices;
  103. }
  104. bool DecisionNode::isLeaf () const
  105. {
  106. return ( (right == NULL) && (left == NULL) );
  107. }