DecisionTree.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. /**
  2. * @file DecisionTree.cpp
  3. * @brief decision tree implementation
  4. * @author Erik Rodner
  5. * @date 04/24/2008
  6. */
#include <cassert>
#include <iostream>

#include "vislearning/classifier/fpclassifier/randomforest/DecisionTree.h"
#include "vislearning/features/fpfeatures/createFeatures.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. // refactor-nice.pl: check this substitution
  13. // old: using namespace ice;
  14. using namespace NICE;
  15. DecisionTree::DecisionTree( const Config *_conf, int _maxClassNo ) : conf(_conf)
  16. {
  17. root = NULL;
  18. maxClassNo = _maxClassNo;
  19. }
  20. DecisionTree::~DecisionTree()
  21. {
  22. deleteNodes ( root );
  23. }
  24. void DecisionTree::statistics ( int & depth, int & count ) const
  25. {
  26. if ( root == NULL )
  27. {
  28. depth = 0;
  29. count = 0;
  30. } else {
  31. root->statistics ( depth, count );
  32. }
  33. }
  34. void DecisionTree::traverse (
  35. const Example & ce,
  36. FullVector & distribution )
  37. {
  38. assert(root != NULL);
  39. root->traverse ( ce, distribution );
  40. }
  41. void DecisionTree::deleteNodes ( DecisionNode *tree )
  42. {
  43. if ( tree != NULL )
  44. {
  45. deleteNodes ( tree->left );
  46. deleteNodes ( tree->right );
  47. delete tree;
  48. }
  49. }
/** Rebuild the tree from a stream written by store().
 *  Format: a sequence of records
 *    NODE <id> <left-id> <right-id>
 *    <feature-tag + feature data + threshold> | LEAF
 *    <classno> <score> ... -1
 *  Nodes are first read into a flat id->node index, then linked.
 *  Id 0 means "no child"; id 1 is the root. Inconsistent input
 *  terminates the process via exit(-1). */
void DecisionTree::restore (istream & is, int format)
{
  // indexing: node id -> node pointer; 0 is reserved for "no node"
  map<long, DecisionNode *> index;
  // node id -> (left child id, right child id), resolved after reading
  map<long, pair<long, long> > descendants;
  index.insert ( pair<long, DecisionNode *> ( 0, NULL ) );
  // refactor-nice.pl: check this substitution
  // old: string tag;
  std::string tag;
  // read flat node records until the stream ends or a non-NODE tag appears
  while ( (! is.eof()) && ( (is >> tag) && (tag == "NODE") ) )
  {
    long ind;
    long ind_l;
    long ind_r;
    if (! (is >> ind)) break;
    if (! (is >> ind_l)) break;
    if (! (is >> ind_r)) break;
    descendants.insert ( pair<long, pair<long, long> > ( ind, pair<long, long> ( ind_l, ind_r ) ) );
    DecisionNode *node = new DecisionNode();
    index.insert ( pair<long, DecisionNode *> ( ind, node ) );
    std::string feature_tag;
    is >> feature_tag;
    // inner nodes carry a feature + threshold; leaves are tagged "LEAF"
    if ( feature_tag != "LEAF" )
    {
      node->f = createFeatureFromTag ( conf, feature_tag );
      if ( node->f == NULL )
      {
        fprintf (stderr, "Unknown feature description: %s\n",
        feature_tag.c_str() );
        exit(-1);
      }
      node->f->restore ( is, format );
      is >> node->threshold;
    }
    FullVector distribution ( maxClassNo+1 );
    int classno;
    double score;
    //distribution.restore ( is );
    // read (classno, score) pairs; the list is terminated by a negative classno
    is >> classno;
    while ( classno >= 0 )
    {
      is >> score;
      if ( classno > maxClassNo )
      {
        fprintf (stderr, "classno: %d; maxClassNo: %d\n", classno, maxClassNo);
        exit(-1);
      }
      distribution[classno] = score;
      is >> classno;
    }
    //distribution.store(cerr);
    node->distribution = distribution;
  }
  // connecting the tree: resolve child ids to node pointers
  for ( map<long, DecisionNode *>::const_iterator i = index.begin();
  i != index.end(); i++ )
  {
    DecisionNode *node = i->second;
    if ( node == NULL ) continue;
    long ind_l = descendants[i->first].first;
    long ind_r = descendants[i->first].second;
    map<long, DecisionNode *>::const_iterator il = index.find ( ind_l );
    map<long, DecisionNode *>::const_iterator ir = index.find ( ind_r );
    if ( ( il == index.end() ) || ( ir == index.end() ) )
    {
      fprintf (stderr, "File inconsistent: unable to build tree\n");
      exit(-1);
    }
    DecisionNode *left = il->second;
    DecisionNode *right = ir->second;
    node->left = left;
    node->right = right;
  }
  // the root is by convention the node with id 1
  map<long, DecisionNode *>::const_iterator iroot = index.find ( 1 );
  if ( iroot == index.end() )
  {
    fprintf (stderr, "File inconsistent: unable to build tree (root node not found)\n");
    exit(-1);
  }
  root = iroot->second;
}
  131. void DecisionTree::store (ostream & os, int format) const
  132. {
  133. if ( root == NULL ) return;
  134. // indexing
  135. map<DecisionNode *, pair<long, int> > index;
  136. index.insert ( pair<DecisionNode *, pair<long, int> > ( NULL, pair<long, int> ( 0, 0 ) ) );
  137. index.insert ( pair<DecisionNode *, pair<long, int> > ( root, pair<long, int> ( 1, 0 ) ) );
  138. long maxindex = 1;
  139. root->indexDescendants ( index, maxindex, 0 );
  140. for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
  141. i != index.end();
  142. i++ )
  143. {
  144. DecisionNode *node = i->first;
  145. if ( node == NULL ) continue;
  146. long ind = i->second.first;
  147. long ind_l = index[ node->left ].first;
  148. long ind_r = index[ node->right ].first;
  149. os << "NODE " << ind << " " << ind_l << " " << ind_r << endl;
  150. Feature *f = node->f;
  151. if ( f != NULL ) {
  152. f->store ( os, format );
  153. os << endl;
  154. os << node->threshold;
  155. os << endl;
  156. } else {
  157. os << "LEAF";
  158. os << endl;
  159. }
  160. const FullVector & distribution = node->distribution;
  161. for ( int i = 0 ; i < distribution.size() ; i++ )
  162. os << i << " " << distribution[i] << " ";
  163. os << -1 << endl;
  164. //distribution.store ( os );
  165. }
  166. }
  167. void DecisionTree::clear ()
  168. {
  169. deleteNodes ( root );
  170. }
  171. void DecisionTree::resetCounters ()
  172. {
  173. if ( root != NULL )
  174. root->resetCounters ();
  175. }
  176. void DecisionTree::indexDescendants ( map<DecisionNode *, pair<long, int> > & index, long & maxindex ) const
  177. {
  178. if ( root != NULL )
  179. root->indexDescendants ( index, maxindex, 0 );
  180. }
  181. DecisionNode *DecisionTree::getLeafNode ( Example & pce, int maxdepth )
  182. {
  183. return root->getLeafNode ( pce, maxdepth );
  184. }
  185. void DecisionTree::getLeaves(DecisionNode *node, vector<DecisionNode*> &leaves)
  186. {
  187. if(node->left == NULL && node->right == NULL)
  188. {
  189. leaves.push_back(node);
  190. return;
  191. }
  192. getLeaves(node->right, leaves);
  193. getLeaves(node->left, leaves);
  194. }
  195. vector<DecisionNode *> DecisionTree::getAllLeafNodes()
  196. {
  197. vector<DecisionNode*> leaves;
  198. getLeaves(root, leaves);
  199. return leaves;
  200. }
  201. DecisionNode *DecisionTree::pruneTreeEntropy ( DecisionNode *node, double minEntropy )
  202. {
  203. if ( node == NULL ) return NULL;
  204. double entropy = node->distribution.entropy();
  205. if ( entropy < minEntropy )
  206. {
  207. deleteNodes ( node );
  208. return NULL;
  209. } else {
  210. node->left = pruneTreeEntropy ( node->left, minEntropy );
  211. node->right = pruneTreeEntropy ( node->right, minEntropy );
  212. return node;
  213. }
  214. }
  215. DecisionNode *DecisionTree::pruneTreeScore ( DecisionNode *node, double minScore )
  216. {
  217. if ( node == NULL ) return NULL;
  218. double score = node->distribution.max();
  219. if ( score < minScore )
  220. {
  221. deleteNodes ( node );
  222. return NULL;
  223. } else {
  224. node->left = pruneTreeScore ( node->left, minScore );
  225. node->right = pruneTreeScore ( node->right, minScore );
  226. return node;
  227. }
  228. }
  229. void DecisionTree::pruneTreeScore ( double minScore )
  230. {
  231. int depth, count;
  232. statistics ( depth, count );
  233. fprintf (stderr, "DecisionTree::pruneTreeScore: depth %d count %d\n", depth, count );
  234. root = pruneTreeScore ( root, minScore );
  235. statistics ( depth, count );
  236. fprintf (stderr, "DecisionTree::pruneTreeScore: depth %d count %d (modified)\n", depth, count );
  237. }
  238. void DecisionTree::pruneTreeEntropy ( double minEntropy )
  239. {
  240. int depth, count;
  241. statistics ( depth, count );
  242. fprintf (stderr, "DecisionTree::entropyTreeScore: depth %d count %d\n", depth, count );
  243. root = pruneTreeEntropy ( root, minEntropy );
  244. statistics ( depth, count );
  245. fprintf (stderr, "DecisionTree::entropyTreeScore: depth %d count %d (modified)\n", depth, count );
  246. }
  247. void DecisionTree::normalize (DecisionNode *node)
  248. {
  249. if ( node != NULL )
  250. {
  251. node->distribution.normalize();
  252. normalize ( node->left );
  253. normalize ( node->right );
  254. }
  255. }
  256. void DecisionTree::normalize ()
  257. {
  258. normalize ( root );
  259. }
  260. void DecisionTree::setRoot ( DecisionNode *newroot )
  261. {
  262. root = newroot;
  263. }