// DecisionTree.cpp
  1. /**
  2. * @file DecisionTree.cpp
  3. * @brief decision tree implementation
  4. * @author Erik Rodner
  5. * @date 04/24/2008
  6. */
  7. #ifdef NOVISUAL
  8. #include <vislearning/nice_nonvis.h>
  9. #else
  10. #include <vislearning/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include "vislearning/classifier/fpclassifier/randomforest/DecisionTree.h"
  14. #include "vislearning/features/fpfeatures/createFeatures.h"
  15. using namespace OBJREC;
  16. using namespace std;
  17. // refactor-nice.pl: check this substitution
  18. // old: using namespace ice;
  19. using namespace NICE;
  20. DecisionTree::DecisionTree( const Config *_conf, int _maxClassNo ) : conf(_conf)
  21. {
  22. root = NULL;
  23. maxClassNo = _maxClassNo;
  24. }
  25. DecisionTree::~DecisionTree()
  26. {
  27. deleteNodes ( root );
  28. }
  29. void DecisionTree::statistics ( int & depth, int & count ) const
  30. {
  31. if ( root == NULL )
  32. {
  33. depth = 0;
  34. count = 0;
  35. } else {
  36. root->statistics ( depth, count );
  37. }
  38. }
  39. void DecisionTree::traverse (
  40. const Example & ce,
  41. FullVector & distribution )
  42. {
  43. assert(root != NULL);
  44. root->traverse ( ce, distribution );
  45. }
  46. void DecisionTree::deleteNodes ( DecisionNode *tree )
  47. {
  48. if ( tree != NULL )
  49. {
  50. deleteNodes ( tree->left );
  51. deleteNodes ( tree->right );
  52. delete tree;
  53. }
  54. }
  55. void DecisionTree::restore (istream & is, int format)
  56. {
  57. // indexing
  58. map<long, DecisionNode *> index;
  59. map<long, pair<long, long> > descendants;
  60. index.insert ( pair<long, DecisionNode *> ( 0, NULL ) );
  61. // refactor-nice.pl: check this substitution
  62. // old: string tag;
  63. std::string tag;
  64. while ( (! is.eof()) && ( (is >> tag) && (tag == "NODE") ) )
  65. {
  66. long ind;
  67. long ind_l;
  68. long ind_r;
  69. if (! (is >> ind)) break;
  70. if (! (is >> ind_l)) break;
  71. if (! (is >> ind_r)) break;
  72. descendants.insert ( pair<long, pair<long, long> > ( ind, pair<long, long> ( ind_l, ind_r ) ) );
  73. DecisionNode *node = new DecisionNode();
  74. index.insert ( pair<long, DecisionNode *> ( ind, node ) );
  75. std::string feature_tag;
  76. is >> feature_tag;
  77. if ( feature_tag != "LEAF" )
  78. {
  79. node->f = createFeatureFromTag ( conf, feature_tag );
  80. if ( node->f == NULL )
  81. {
  82. fprintf (stderr, "Unknown feature description: %s\n",
  83. feature_tag.c_str() );
  84. exit(-1);
  85. }
  86. node->f->restore ( is, format );
  87. is >> node->threshold;
  88. }
  89. FullVector distribution ( maxClassNo+1 );
  90. int classno;
  91. double score;
  92. //distribution.restore ( is );
  93. is >> classno;
  94. while ( classno >= 0 )
  95. {
  96. is >> score;
  97. if ( classno > maxClassNo )
  98. {
  99. fprintf (stderr, "classno: %d; maxClassNo: %d\n", classno, maxClassNo);
  100. exit(-1);
  101. }
  102. distribution[classno] = score;
  103. is >> classno;
  104. }
  105. //distribution.store(cerr);
  106. node->distribution = distribution;
  107. }
  108. // connecting the tree
  109. for ( map<long, DecisionNode *>::const_iterator i = index.begin();
  110. i != index.end(); i++ )
  111. {
  112. DecisionNode *node = i->second;
  113. if ( node == NULL ) continue;
  114. long ind_l = descendants[i->first].first;
  115. long ind_r = descendants[i->first].second;
  116. map<long, DecisionNode *>::const_iterator il = index.find ( ind_l );
  117. map<long, DecisionNode *>::const_iterator ir = index.find ( ind_r );
  118. if ( ( il == index.end() ) || ( ir == index.end() ) )
  119. {
  120. fprintf (stderr, "File inconsistent: unable to build tree\n");
  121. exit(-1);
  122. }
  123. DecisionNode *left = il->second;
  124. DecisionNode *right = ir->second;
  125. node->left = left;
  126. node->right = right;
  127. }
  128. map<long, DecisionNode *>::const_iterator iroot = index.find ( 1 );
  129. if ( iroot == index.end() )
  130. {
  131. fprintf (stderr, "File inconsistent: unable to build tree (root node not found)\n");
  132. exit(-1);
  133. }
  134. root = iroot->second;
  135. }
  136. void DecisionTree::store (ostream & os, int format) const
  137. {
  138. if ( root == NULL ) return;
  139. // indexing
  140. map<DecisionNode *, pair<long, int> > index;
  141. index.insert ( pair<DecisionNode *, pair<long, int> > ( NULL, pair<long, int> ( 0, 0 ) ) );
  142. index.insert ( pair<DecisionNode *, pair<long, int> > ( root, pair<long, int> ( 1, 0 ) ) );
  143. long maxindex = 1;
  144. root->indexDescendants ( index, maxindex, 0 );
  145. for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
  146. i != index.end();
  147. i++ )
  148. {
  149. DecisionNode *node = i->first;
  150. if ( node == NULL ) continue;
  151. long ind = i->second.first;
  152. long ind_l = index[ node->left ].first;
  153. long ind_r = index[ node->right ].first;
  154. os << "NODE " << ind << " " << ind_l << " " << ind_r << endl;
  155. Feature *f = node->f;
  156. if ( f != NULL ) {
  157. f->store ( os, format );
  158. os << endl;
  159. os << node->threshold;
  160. os << endl;
  161. } else {
  162. os << "LEAF";
  163. os << endl;
  164. }
  165. const FullVector & distribution = node->distribution;
  166. for ( int i = 0 ; i < distribution.size() ; i++ )
  167. os << i << " " << distribution[i] << " ";
  168. os << -1 << endl;
  169. //distribution.store ( os );
  170. }
  171. }
  172. void DecisionTree::clear ()
  173. {
  174. deleteNodes ( root );
  175. }
  176. void DecisionTree::resetCounters ()
  177. {
  178. if ( root != NULL )
  179. root->resetCounters ();
  180. }
  181. void DecisionTree::indexDescendants ( map<DecisionNode *, pair<long, int> > & index, long & maxindex ) const
  182. {
  183. if ( root != NULL )
  184. root->indexDescendants ( index, maxindex, 0 );
  185. }
  186. DecisionNode *DecisionTree::getLeafNode ( Example & pce, int maxdepth )
  187. {
  188. return root->getLeafNode ( pce, maxdepth );
  189. }
  190. void DecisionTree::getLeaves(DecisionNode *node, vector<DecisionNode*> &leaves)
  191. {
  192. if(node->left == NULL && node->right == NULL)
  193. {
  194. leaves.push_back(node);
  195. return;
  196. }
  197. getLeaves(node->right, leaves);
  198. getLeaves(node->left, leaves);
  199. }
  200. vector<DecisionNode *> DecisionTree::getAllLeafNodes()
  201. {
  202. vector<DecisionNode*> leaves;
  203. getLeaves(root, leaves);
  204. return leaves;
  205. }
  206. DecisionNode *DecisionTree::pruneTreeEntropy ( DecisionNode *node, double minEntropy )
  207. {
  208. if ( node == NULL ) return NULL;
  209. double entropy = node->distribution.entropy();
  210. if ( entropy < minEntropy )
  211. {
  212. deleteNodes ( node );
  213. return NULL;
  214. } else {
  215. node->left = pruneTreeEntropy ( node->left, minEntropy );
  216. node->right = pruneTreeEntropy ( node->right, minEntropy );
  217. return node;
  218. }
  219. }
  220. DecisionNode *DecisionTree::pruneTreeScore ( DecisionNode *node, double minScore )
  221. {
  222. if ( node == NULL ) return NULL;
  223. double score = node->distribution.max();
  224. if ( score < minScore )
  225. {
  226. deleteNodes ( node );
  227. return NULL;
  228. } else {
  229. node->left = pruneTreeScore ( node->left, minScore );
  230. node->right = pruneTreeScore ( node->right, minScore );
  231. return node;
  232. }
  233. }
  234. void DecisionTree::pruneTreeScore ( double minScore )
  235. {
  236. int depth, count;
  237. statistics ( depth, count );
  238. fprintf (stderr, "DecisionTree::pruneTreeScore: depth %d count %d\n", depth, count );
  239. root = pruneTreeScore ( root, minScore );
  240. statistics ( depth, count );
  241. fprintf (stderr, "DecisionTree::pruneTreeScore: depth %d count %d (modified)\n", depth, count );
  242. }
  243. void DecisionTree::pruneTreeEntropy ( double minEntropy )
  244. {
  245. int depth, count;
  246. statistics ( depth, count );
  247. fprintf (stderr, "DecisionTree::entropyTreeScore: depth %d count %d\n", depth, count );
  248. root = pruneTreeEntropy ( root, minEntropy );
  249. statistics ( depth, count );
  250. fprintf (stderr, "DecisionTree::entropyTreeScore: depth %d count %d (modified)\n", depth, count );
  251. }
  252. void DecisionTree::normalize (DecisionNode *node)
  253. {
  254. if ( node != NULL )
  255. {
  256. node->distribution.normalize();
  257. normalize ( node->left );
  258. normalize ( node->right );
  259. }
  260. }
  261. void DecisionTree::normalize ()
  262. {
  263. normalize ( root );
  264. }
  265. void DecisionTree::setRoot ( DecisionNode *newroot )
  266. {
  267. root = newroot;
  268. }