/**
* @file FPCRandomForestTransfer.cpp
* @brief implementation of random set forests with knowledge transfer
* @author Erik Rodner
* @date 04/24/2008
*/
  7. #ifdef NOVISUAL
  8. #include <vislearning/nice_nonvis.h>
  9. #else
  10. #include <vislearning/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include <list>
  14. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h"
  15. #include "vislearning/classifier/fpclassifier/randomforest/DTBStandard.h"
  16. #include "vislearning/classifier/fpclassifier/randomforest/DTBRandom.h"
  17. using namespace OBJREC;
  18. using namespace std;
  19. using namespace NICE;
  20. FPCRandomForestTransfer::FPCRandomForestTransfer( const Config *_conf,
  21. const ClassNames *classNames, std::string section ) :
  22. FPCRandomForests ( _conf, section ), dte ( _conf, section )
  23. {
  24. reduce_training_set = _conf->gB(section, "reduce_training_set", false);
  25. entropy_rejection_threshold = _conf->gD(section, "entropy_rejection_threshold", 0.0 );
  26. extend_only_critical_leafs = _conf->gB(section, "extend_only_critical_leafs", true );
  27. if ( reduce_training_set ) {
  28. training_absolute = _conf->gI ( section, "training_absolute", -1 );
  29. if ( training_absolute < 0 )
  30. training_ratio = _conf->gD ( section, "training_ratio" );
  31. }
  32. std::string substituteClasses_s = _conf->gS ( section, "substitute_classes" );
  33. classNames->getSelection ( substituteClasses_s, substituteClasses );
  34. std::string muClasses_s = _conf->gS ( section, "mu_classes" );
  35. classNames->getSelection ( muClasses_s, muClasses );
  36. std::string newClass_s = _conf->gS ( section, "new_classes" );
  37. classNames->getSelection ( newClass_s, newClass );
  38. sigmaq = _conf->gD ( section, "sigmaq" );
  39. cached_prior_structure = _conf->gS(section, "cached_prior_structure", "prior.tree" );
  40. read_cached_prior_structure = _conf->gB(section, "read_cached_prior_structure", false );
  41. if ( newClass.size() != 1 )
  42. {
  43. fprintf (stderr, "Multi-New-Class stuff not yet implemented\n");
  44. exit(-1);
  45. }
  46. partial_ml_estimation = _conf->gB(section, "partial_ml_estimation", false );
  47. partial_ml_estimation_depth = _conf->gI(section, "partial_ml_estimation_depth", 4 );
  48. extend_map_tree = _conf->gB(section, "extend_map_tree", false );
  49. if ( extend_map_tree )
  50. {
  51. std::string builder_e_method = _conf->gS(section, "builder_extend", "random" );
  52. std::string builder_e_section = _conf->gS(section, "builder_extend_section" );
  53. if ( builder_e_method == "standard" )
  54. builder_extend = new DTBStandard ( _conf, builder_e_section );
  55. else if (builder_e_method == "random" )
  56. builder_extend = new DTBRandom ( _conf, builder_e_section );
  57. else {
  58. fprintf (stderr, "DecisionTreeBuilder %s not yet implemented !\n",
  59. builder_e_method.c_str() );
  60. exit(-1);
  61. }
  62. }
  63. learn_ert_with_newclass = _conf->gB(section, "learn_ert_with_newclass", false);
  64. }
  65. FPCRandomForestTransfer::~FPCRandomForestTransfer()
  66. {
  67. }
  68. void FPCRandomForestTransfer::mlEstimate ( DecisionNode *node,
  69. Examples & examples_new,
  70. int newClassNo )
  71. {
  72. node->resetCounters();
  73. for ( Examples::iterator i = examples_new.begin() ;
  74. i != examples_new.end();
  75. i++ )
  76. {
  77. FullVector distribution (maxClassNo+1);
  78. assert ( i->first == newClassNo );
  79. node->traverse ( i->second, distribution );
  80. }
  81. map<DecisionNode *, pair<long, int> > index;
  82. long maxindex = 0;
  83. node->indexDescendants ( index, maxindex, 0 );
  84. for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
  85. i != index.end();
  86. i++ )
  87. {
  88. DecisionNode *node = i->first;
  89. node->distribution[newClassNo] = node->counter;
  90. }
  91. }
  92. void FPCRandomForestTransfer::partialMLEstimate ( DecisionTree & tree,
  93. Examples & examples_new,
  94. int newClassNo,
  95. int mldepth )
  96. {
  97. map<DecisionNode *, pair<long, int> > index;
  98. long maxindex = 0;
  99. tree.indexDescendants ( index, maxindex );
  100. for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
  101. i != index.end();
  102. i++ )
  103. {
  104. DecisionNode *node = i->first;
  105. pair<long, int> & data = i->second;
  106. int depth = data.second;
  107. if ( depth == mldepth ) {
  108. // I do not care whether this is a leaf node or not
  109. Examples examples_new_rw;
  110. examples_new_rw.insert ( examples_new_rw.begin(),
  111. examples_new.begin(),
  112. examples_new.end() );
  113. // reweight examples
  114. double weight = ( node->distribution.get ( newClassNo ) );
  115. if ( fabs(weight) < 10e-10 )
  116. {
  117. continue;
  118. }
  119. for ( Examples::iterator j = examples_new_rw.begin();
  120. j != examples_new_rw.end() ;
  121. j++ )
  122. {
  123. j->second.weight = weight / examples_new_rw.size();
  124. }
  125. mlEstimate ( node, examples_new_rw, newClassNo );
  126. }
  127. }
  128. }
/**
 * Extend selected leaf nodes of a MAP-adapted tree with newly grown subtrees.
 *
 * Every transfer example belonging to the new class or to one of the mu
 * (support) classes is dropped down @p tree; leaves that collect such
 * examples and pass the weight / criticality / entropy filters are rebuilt
 * with builder_extend from those transfer examples plus all new-class
 * examples. The class probabilities of each new subtree node are then
 * rescaled so that non-support classes keep the probability mass of the
 * original leaf.
 *
 * @param fp feature pool handed to the subtree builder
 * @param tree tree whose leaves are extended (modified in place)
 * @param examples_transfer examples of the previous (transfer) tasks
 * @param examples_new examples of the new class
 * @param newClassNo class number of the new class
 * @param muClasses class numbers of the support (mu) classes
 */
void FPCRandomForestTransfer::extendMapTree ( FeaturePool & fp,
    DecisionTree & tree,
    Examples & examples_transfer,
    Examples & examples_new,
    int newClassNo,
    const set<int> & muClasses )
{
  // for each leaf: indices (into examples_transfer) of the examples that
  // landed in that leaf and survived all filters below
  map<DecisionNode *, set<int> > examplesTransferLeafs;

  fprintf (stderr, "FPCRandomForestTransfer: classify all %ld transfer examples\n",
      examples_transfer.size());

  int index = 0;
  for ( Examples::iterator i = examples_transfer.begin() ;
        i != examples_transfer.end();
        i++, index++ )
  {
    Example & pce = i->second;
    int example_classno = i->first;

    // only examples of the new class or a support class are considered
    if ( (example_classno != newClassNo) &&
         (muClasses.find(example_classno) == muClasses.end() ) )
      continue;
    else
      fprintf (stderr, "suitable example of class %d found !\n", example_classno);

    DecisionNode *leaf = tree.getLeafNode ( pce );

    // skip leaves with (nearly) no probability mass for the new class
    double weight = ( leaf->distribution.get ( newClassNo ) );
    if ( fabs(weight) < 10e-2 )
      continue;

    // optionally extend only leaves currently "owned" by a support class
    if ( extend_only_critical_leafs )
    {
      int maxClass = leaf->distribution.maxElement();
      if ( muClasses.find(maxClass) == muClasses.end() )
        continue;
    }

    // entropy of the leaf distribution, normalized by its maximum log(size)
    double avgentropy = leaf->distribution.entropy() / log(leaf->distribution.size());

    if ( examplesTransferLeafs.find(leaf) == examplesTransferLeafs.end() )
    {
      // debug output for leaves seen for the first time (currently disabled)
      /*fprintf (stderr, "FPCRandomForestTransfer: leaf owned by %d (normalized entropy %f)\n", maxClass, avgentropy );
      leaf->distribution.store(cerr); */
    }

    // reject leaves whose distribution is already too pure
    if ( avgentropy < entropy_rejection_threshold )
    {
      fprintf (stderr, "FPCRandomForestTransfer: leaf rejected due to entropy %f < %f!\n", avgentropy, entropy_rejection_threshold);
      continue;
    }

    examplesTransferLeafs[leaf].insert ( index );
  }

  fprintf (stderr, "FPCRandomForestTransfer: %ld leaf nodes will be extended\n",
      examplesTransferLeafs.size() );
  fprintf (stderr, "FPCRandomForestTransfer: Extending Leaf Nodes !\n");

  // rebuild each selected leaf
  for ( map<DecisionNode *, set<int> >::iterator k = examplesTransferLeafs.begin();
        k != examplesTransferLeafs.end();
        k++ )
  {
    DecisionNode *node = k->first;
    FullVector examples_counts ( maxClassNo+1 );
    Examples examples_node;
    set<int> & examplesset = k->second;

    // gather the transfer examples collected for this leaf, skipping classes
    // with (nearly) zero probability in the leaf distribution
    for ( set<int>::iterator i = examplesset.begin(); i != examplesset.end(); i++ )
    {
      pair<int, Example> & example = examples_transfer[ *i ];
      if ( node->distribution [ example.first ] < 10e-11 )
        continue;
      examples_node.push_back ( example );
      examples_counts[ example.first ]++;
    }

    fprintf (stderr, "FPCRandomForestTransfer: Examples from support classes %ld\n", examples_node.size() );
    fprintf (stderr, "FPCRandomForestTransfer: Examples from new class %ld (classno %d)\n", examples_new.size(),
        newClassNo);

    // add all new-class examples to the training set of the subtree
    examples_node.insert ( examples_node.begin(), examples_new.begin(), examples_new.end() );
    examples_counts[newClassNo] = examples_new.size();

    fprintf (stderr, "FPCRandomForestTransfer: Extending leaf node with %ld examples\n", examples_node.size() );

    // reweight: each class contributes its leaf probability, split evenly
    // over that class's examples
    for ( Examples::iterator j = examples_node.begin();
          j != examples_node.end() ;
          j++ )
    {
      int classno = j->first;
      double weight = ( node->distribution.get ( classno ) );
      fprintf (stderr, "examples_counts[%d] = %f; weight %f\n", classno, examples_counts[classno], weight );
      j->second.weight = weight / examples_counts[classno];
    }

    // grow a new subtree and graft it into the tree, restoring the leaf's
    // original distribution at the subtree root
    DecisionNode *newnode = builder_extend->build ( fp, examples_node, maxClassNo );
    FullVector orig_distribution ( node->distribution );
    node->copy ( newnode );
    node->distribution = orig_distribution;
    orig_distribution.normalize();
    orig_distribution.store(cerr);

    // total probability mass of the support classes (incl. the new class)
    // in the original leaf
    double support_node_sum = 0.0;
    for ( int classi = 0 ; classi < node->distribution.size() ; classi++ )
      if ( (classi == newClassNo) || (muClasses.find(classi) != muClasses.end() ) )
        support_node_sum += node->distribution[classi];

    // set all probabilities for non support classes
    std::list<DecisionNode *> stack;
    stack.push_back ( node );
    while ( stack.size() > 0 )
    {
      DecisionNode *cnode = stack.front();
      stack.pop_front();

      // non-support classes inherit the root (old leaf) probabilities;
      // meanwhile accumulate this node's support-class mass
      double cnode_sum = 0.0;
      for ( int classi = 0 ; classi < cnode->distribution.size() ; classi++ )
        if ( (classi != newClassNo) && (muClasses.find(classi) == muClasses.end() ) )
          cnode->distribution[classi] = node->distribution[classi];
        else
          cnode_sum += cnode->distribution[classi];

      // rescale support-class entries so their total mass equals that of
      // the original leaf
      if ( fabs(cnode_sum) > 10e-11 )
        for ( int classi = 0 ; classi < node->distribution.size() ; classi++ )
          if ( (classi == newClassNo) || (muClasses.find(classi) != muClasses.end() ) )
            cnode->distribution[classi] *= support_node_sum / cnode_sum;

      // debug: dump the normalized distribution of each new leaf
      if ( (cnode->left == NULL) && (cnode->right == NULL ) )
      {
        FullVector stuff ( cnode->distribution );
        stuff.normalize();
        stuff.store(cerr);
      }

      if ( cnode->left != NULL )
        stack.push_back ( cnode->left );
      if ( cnode->right != NULL )
        stack.push_back ( cnode->right );
    }
  }
  fprintf (stderr, "FPCRandomForestTransfer: MAP tree extension done !\n");
}
  249. void FPCRandomForestTransfer::train ( FeaturePool & fp,
  250. Examples & examples )
  251. {
  252. maxClassNo = examples.getMaxClassNo();
  253. fprintf (stderr, "############### FPCRandomForestTransfer::train ####################\n");
  254. assert ( newClass.size() == 1 );
  255. int newClassNo = *(newClass.begin());
  256. // reduce training set
  257. Examples examples_new;
  258. Examples examples_transfer;
  259. for ( Examples::const_iterator i = examples.begin();
  260. i != examples.end();
  261. i++ )
  262. {
  263. int classno = i->first;
  264. if ( newClass.find(classno) != newClass.end() ) {
  265. examples_new.push_back ( *i );
  266. } else {
  267. examples_transfer.push_back ( *i );
  268. }
  269. }
  270. if ( examples_new.size() <= 0 )
  271. {
  272. if ( newClass.size() <= 0 ) {
  273. fprintf (stderr, "FPCRandomForestTransfer::train: no new classes given !\n");
  274. } else {
  275. fprintf (stderr, "FPCRandomForestTransfer::train: no examples found of class %d\n", newClassNo );
  276. }
  277. exit(-1);
  278. }
  279. if ( reduce_training_set )
  280. {
  281. // reduce training set
  282. random_shuffle ( examples_new.begin(), examples_new.end() );
  283. int oldsize = (int)examples_new.size();
  284. int newsize;
  285. if ( training_absolute < 0 )
  286. newsize = (int)(training_ratio*examples_new.size());
  287. else
  288. newsize = training_absolute;
  289. Examples::iterator j = examples_new.begin() + newsize;
  290. examples_new.erase ( j, examples_new.end() );
  291. fprintf (stderr, "Size of training set randomly reduced from %d to %d\n", oldsize,
  292. (int)examples_new.size() );
  293. }
  294. if ( read_cached_prior_structure )
  295. {
  296. FPCRandomForests::read ( cached_prior_structure );
  297. } else {
  298. if ( learn_ert_with_newclass )
  299. {
  300. FPCRandomForests::train ( fp, examples );
  301. } else {
  302. FPCRandomForests::train ( fp, examples_transfer );
  303. }
  304. FPCRandomForests::save ( cached_prior_structure );
  305. }
  306. fprintf (stderr, "MAP ESTIMATION sigmaq = %e\n", sigmaq);
  307. for ( vector<DecisionTree *>::iterator i = forest.begin();
  308. i != forest.end();
  309. i++ )
  310. {
  311. DecisionTree & tree = *(*i);
  312. dte.reestimate ( tree,
  313. examples_new,
  314. sigmaq,
  315. newClassNo,
  316. muClasses,
  317. substituteClasses,
  318. maxClassNo);
  319. if ( partial_ml_estimation )
  320. {
  321. partialMLEstimate ( tree,
  322. examples_new,
  323. newClassNo,
  324. partial_ml_estimation_depth );
  325. }
  326. if ( extend_map_tree )
  327. {
  328. fp.initRandomFeatureSelection ();
  329. extendMapTree ( fp,
  330. tree,
  331. examples_transfer,
  332. examples_new,
  333. newClassNo,
  334. muClasses);
  335. }
  336. }
  337. save ( "map.tree" );
  338. }
  339. FeaturePoolClassifier *FPCRandomForestTransfer::clone () const
  340. {
  341. fprintf (stderr, "FPCRandomForestTransfer::clone() not yet implemented !\n");
  342. exit(-1);
  343. }