// RegressionTree.cpp (5.8 KB)

  1. /**
  2. * @file RegressionTree.cpp
  3. * @brief regression tree implementation
  4. * @author Sven Sickert
  5. * @date 06/19/2013
  6. */
  #include <iostream>
  #include <sstream>
  #include <assert.h>
  #include "vislearning/regression/randomforest/RegressionTree.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. RegressionTree::RegressionTree( const Config *_conf ) : conf(_conf)
  14. {
  15. root = NULL;
  16. }
  17. RegressionTree::~RegressionTree()
  18. {
  19. deleteNodes ( root );
  20. }
  21. void RegressionTree::statistics ( int & depth, int & count ) const
  22. {
  23. if ( root == NULL )
  24. {
  25. depth = 0;
  26. count = 0;
  27. } else {
  28. root->statistics ( depth, count );
  29. }
  30. }
  31. void RegressionTree::traverse (
  32. const Vector & x,
  33. double & predVal )
  34. {
  35. assert( root != NULL );
  36. root->traverse ( x, predVal );
  37. }
  38. void RegressionTree::deleteNodes ( RegressionNode *tree )
  39. {
  40. if ( tree != NULL )
  41. {
  42. deleteNodes ( tree->left );
  43. deleteNodes ( tree->right );
  44. delete tree;
  45. }
  46. }
  47. void RegressionTree::clear ()
  48. {
  49. deleteNodes ( root );
  50. }
  51. void RegressionTree::resetCounters ()
  52. {
  53. if ( root != NULL )
  54. root->resetCounters ();
  55. }
  56. void RegressionTree::indexDescendants (
  57. map<RegressionNode *, pair<long, int> > & index,
  58. long & maxindex ) const
  59. {
  60. if ( root != NULL )
  61. root->indexDescendants ( index, maxindex, 0 );
  62. }
  63. RegressionNode *RegressionTree::getLeafNode (
  64. Vector & x,
  65. int maxdepth )
  66. {
  67. return root->getLeafNode ( x, maxdepth );
  68. }
  69. void RegressionTree::getLeaves(
  70. RegressionNode *node,
  71. vector<RegressionNode*> &leaves)
  72. {
  73. if(node->left == NULL && node->right == NULL)
  74. {
  75. leaves.push_back(node);
  76. return;
  77. }
  78. getLeaves(node->right, leaves);
  79. getLeaves(node->left, leaves);
  80. }
  81. vector<RegressionNode *> RegressionTree::getAllLeafNodes()
  82. {
  83. vector<RegressionNode*> leaves;
  84. getLeaves(root, leaves);
  85. return leaves;
  86. }
  87. void RegressionTree::setRoot ( RegressionNode *newroot )
  88. {
  89. root = newroot;
  90. }
  91. RegressionNode *RegressionTree::pruneTreeLeastSquares (
  92. RegressionNode *node,
  93. double minErrorReduction,
  94. double & lsError )
  95. {
  96. if ( node == NULL ) return NULL;
  97. lsError = node->lsError;
  98. double leftError, rightError;
  99. node->left = pruneTreeLeastSquares ( node->left, minErrorReduction, leftError );
  100. node->right = pruneTreeLeastSquares ( node->right, minErrorReduction, rightError );
  101. if (node->left != NULL && node->right != NULL)
  102. {
  103. if (lsError-leftError-rightError < minErrorReduction)
  104. {
  105. deleteNodes( node->left );
  106. deleteNodes( node->right );
  107. }
  108. }
  109. return node;
  110. }
  111. void RegressionTree::pruneTreeLeastSquares ( double minErrorReduction )
  112. {
  113. int depth, count;
  114. statistics ( depth, count );
  115. fprintf (stderr, "RegressionTree::pruneTreeLeastSquares: depth %d count %d\n", depth, count );
  116. double tmp;
  117. root = pruneTreeLeastSquares ( root, minErrorReduction, tmp );
  118. statistics ( depth, count );
  119. fprintf (stderr, "RegressionTree::pruneTreeLeastSquares: depth %d count %d (modified)\n", depth, count );
  120. }
  121. void RegressionTree::store (ostream & os, int format) const
  122. {
  123. if ( root == NULL ) return;
  124. // indexing
  125. map<RegressionNode *, pair<long, int> > index;
  126. index.insert ( pair<RegressionNode *, pair<long, int> > ( NULL, pair<long, int> ( 0, 0 ) ) );
  127. index.insert ( pair<RegressionNode *, pair<long, int> > ( root, pair<long, int> ( 1, 0 ) ) );
  128. long maxindex = 1;
  129. root->indexDescendants ( index, maxindex, 0 );
  130. for ( map<RegressionNode *, pair<long, int> >::iterator i = index.begin();
  131. i != index.end();
  132. i++ )
  133. {
  134. RegressionNode *node = i->first;
  135. if ( node == NULL ) continue;
  136. long ind = i->second.first;
  137. long ind_l = index[ node->left ].first;
  138. long ind_r = index[ node->right ].first;
  139. os << "NODE " << ind << " " << ind_l << " " << ind_r << endl;
  140. if ( !node->isLeaf() ) {
  141. os << node->f;
  142. os << endl;
  143. os << node->threshold;
  144. os << endl;
  145. } else {
  146. os << "LEAF";
  147. os << endl;
  148. }
  149. os << node->lsError << " " << -1 << endl;
  150. }
  151. }
  152. void RegressionTree::restore (istream & is, int format)
  153. {
  154. // indexing
  155. map<long, RegressionNode *> index;
  156. map<long, pair<long, long> > descendants;
  157. index.insert ( pair<long, RegressionNode *> ( 0, NULL ) );
  158. // refactor-nice.pl: check this substitution
  159. // old: string tag;
  160. std::string tag;
  161. while ( (! is.eof()) && ( (is >> tag) && (tag == "NODE") ) )
  162. {
  163. long ind;
  164. long ind_l;
  165. long ind_r;
  166. if (! (is >> ind)) break;
  167. if (! (is >> ind_l)) break;
  168. if (! (is >> ind_r)) break;
  169. descendants.insert ( pair<long, pair<long, long> > ( ind, pair<long, long> ( ind_l, ind_r ) ) );
  170. RegressionNode *node = new RegressionNode();
  171. index.insert ( pair<long, RegressionNode *> ( ind, node ) );
  172. std::string feature_tag;
  173. is >> feature_tag;
  174. if ( feature_tag != "LEAF" )
  175. {
  176. is >> node->f;
  177. is >> node->threshold;
  178. }
  179. is >> node->lsError;
  180. }
  181. // connecting the tree
  182. for ( map<long, RegressionNode *>::const_iterator it = index.begin();
  183. it != index.end(); it++ )
  184. {
  185. RegressionNode *node = it->second;
  186. if ( node == NULL ) continue;
  187. long ind_l = descendants[it->first].first;
  188. long ind_r = descendants[it->first].second;
  189. map<long, RegressionNode *>::const_iterator il = index.find ( ind_l );
  190. map<long, RegressionNode *>::const_iterator ir = index.find ( ind_r );
  191. if ( ( il == index.end() ) || ( ir == index.end() ) )
  192. {
  193. fprintf (stderr, "File inconsistent: unable to build tree\n");
  194. exit(-1);
  195. }
  196. RegressionNode *left = il->second;
  197. RegressionNode *right = ir->second;
  198. node->left = left;
  199. node->right = right;
  200. }
  201. map<long, RegressionNode *>::const_iterator iroot = index.find ( 1 );
  202. if ( iroot == index.end() )
  203. {
  204. fprintf (stderr, "File inconsistent: unable to build tree (root node not found)\n");
  205. exit(-1);
  206. }
  207. root = iroot->second;
  208. }