// DecisionNode.cpp
  /**
   * @file DecisionNode.cpp
   * @brief Implementation of DecisionNode, a node of a binary decision tree
   *        (part of the random forest classifier)
   * @author Erik Rodner
   * @date 04/24/2008
   */
  7. #include <vislearning/nice.h>
  8. #include <iostream>
  9. #include "vislearning/classifier/fpclassifier/randomforest/DecisionNode.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. DecisionNode::~DecisionNode()
  14. {
  15. if ( f != NULL )
  16. delete f;
  17. }
  18. DecisionNode::DecisionNode ()
  19. {
  20. left = NULL;
  21. right = NULL;
  22. f = NULL;
  23. counter = 0;
  24. }
  25. DecisionNode *DecisionNode::getLeafNode (
  26. const Example & ce,
  27. int depth )
  28. {
  29. counter += ce.weight;
  30. if ( (!depth) || ((left == NULL) && (right == NULL)) )
  31. return this;
  32. double val = f->val( &ce );
  33. if ( val < threshold )
  34. if ( left != NULL )
  35. return left->getLeafNode ( ce, depth - 1 );
  36. else
  37. return this;
  38. else
  39. if ( right != NULL )
  40. return right->getLeafNode ( ce, depth - 1 );
  41. else
  42. return this;
  43. }
  44. void DecisionNode::traverse (
  45. const Example & ce,
  46. FullVector & _distribution )
  47. {
  48. DecisionNode *leaf = getLeafNode ( ce );
  49. _distribution = leaf->distribution;
  50. }
  51. void DecisionNode::statistics ( int & depth, int & count ) const
  52. {
  53. int dl, cl;
  54. if ( left != NULL )
  55. {
  56. left->statistics ( dl, cl );
  57. dl++;
  58. } else {
  59. dl = 0;
  60. cl = 0;
  61. }
  62. if ( right != NULL )
  63. {
  64. right->statistics ( depth, count );
  65. depth++;
  66. } else {
  67. depth = 0;
  68. count = 0;
  69. }
  70. depth = (depth > dl) ? depth : dl;
  71. count += cl + 1;
  72. }
  73. void DecisionNode::indexDescendants ( map<DecisionNode *, pair<long, int> > & index, long & maxindex, int depth ) const
  74. {
  75. if ( left != NULL )
  76. {
  77. maxindex++;
  78. index.insert ( pair<DecisionNode *, pair<long, int> > ( left, pair<long, int>( maxindex, depth + 1 ) ) );
  79. left->indexDescendants ( index, maxindex, depth+1 );
  80. }
  81. if ( right != NULL )
  82. {
  83. maxindex++;
  84. index.insert ( pair<DecisionNode *, pair<long, int> > ( right, pair<long, int>( maxindex, depth + 1 ) ) );
  85. right->indexDescendants ( index, maxindex, depth+1 );
  86. }
  87. }
  88. void DecisionNode::resetCounters ()
  89. {
  90. counter = 0;
  91. if ( left != NULL ) left->resetCounters();
  92. if ( right != NULL ) right->resetCounters();
  93. }
  94. void DecisionNode::copy ( DecisionNode *node )
  95. {
  96. left = node->left;
  97. right = node->right;
  98. threshold = node->threshold;
  99. f = node->f;
  100. distribution = node->distribution;
  101. trainExamplesIndices = node->trainExamplesIndices;
  102. }
  103. bool DecisionNode::isLeaf () const
  104. {
  105. return ( (right == NULL) && (left == NULL) );
  106. }