// FPCRandomForestTransfer.cpp
  1. /**
  2. * @file FPCRandomForestTransfer.cpp
  3. * @brief implementation of random set forests
  4. * @author Erik Rodner
  5. * @date 04/24/2008
  6. */
  7. #include <iostream>
  8. #include <list>
  9. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/DTBStandard.h"
  11. #include "vislearning/classifier/fpclassifier/randomforest/DTBRandom.h"
  12. using namespace OBJREC;
  13. using namespace std;
  14. using namespace NICE;
  15. FPCRandomForestTransfer::FPCRandomForestTransfer( const Config *_conf,
  16. const ClassNames *classNames, std::string section ) :
  17. FPCRandomForests ( _conf, section ), dte ( _conf, section )
  18. {
  19. reduce_training_set = _conf->gB(section, "reduce_training_set", false);
  20. entropy_rejection_threshold = _conf->gD(section, "entropy_rejection_threshold", 0.0 );
  21. extend_only_critical_leafs = _conf->gB(section, "extend_only_critical_leafs", true );
  22. if ( reduce_training_set ) {
  23. training_absolute = _conf->gI ( section, "training_absolute", -1 );
  24. if ( training_absolute < 0 )
  25. training_ratio = _conf->gD ( section, "training_ratio" );
  26. }
  27. std::string substituteClasses_s = _conf->gS ( section, "substitute_classes" );
  28. classNames->getSelection ( substituteClasses_s, substituteClasses );
  29. std::string muClasses_s = _conf->gS ( section, "mu_classes" );
  30. classNames->getSelection ( muClasses_s, muClasses );
  31. std::string newClass_s = _conf->gS ( section, "new_classes" );
  32. classNames->getSelection ( newClass_s, newClass );
  33. sigmaq = _conf->gD ( section, "sigmaq" );
  34. cached_prior_structure = _conf->gS(section, "cached_prior_structure", "prior.tree" );
  35. read_cached_prior_structure = _conf->gB(section, "read_cached_prior_structure", false );
  36. if ( newClass.size() != 1 )
  37. {
  38. fprintf (stderr, "Multi-New-Class stuff not yet implemented\n");
  39. exit(-1);
  40. }
  41. partial_ml_estimation = _conf->gB(section, "partial_ml_estimation", false );
  42. partial_ml_estimation_depth = _conf->gI(section, "partial_ml_estimation_depth", 4 );
  43. extend_map_tree = _conf->gB(section, "extend_map_tree", false );
  44. if ( extend_map_tree )
  45. {
  46. std::string builder_e_method = _conf->gS(section, "builder_extend", "random" );
  47. std::string builder_e_section = _conf->gS(section, "builder_extend_section" );
  48. if ( builder_e_method == "standard" )
  49. builder_extend = new DTBStandard ( _conf, builder_e_section );
  50. else if (builder_e_method == "random" )
  51. builder_extend = new DTBRandom ( _conf, builder_e_section );
  52. else {
  53. fprintf (stderr, "DecisionTreeBuilder %s not yet implemented !\n",
  54. builder_e_method.c_str() );
  55. exit(-1);
  56. }
  57. }
  58. learn_ert_with_newclass = _conf->gB(section, "learn_ert_with_newclass", false);
  59. }
  60. FPCRandomForestTransfer::~FPCRandomForestTransfer()
  61. {
  62. }
  63. void FPCRandomForestTransfer::mlEstimate ( DecisionNode *node,
  64. Examples & examples_new,
  65. int newClassNo )
  66. {
  67. node->resetCounters();
  68. for ( Examples::iterator i = examples_new.begin() ;
  69. i != examples_new.end();
  70. i++ )
  71. {
  72. FullVector distribution (maxClassNo+1);
  73. assert ( i->first == newClassNo );
  74. node->traverse ( i->second, distribution );
  75. }
  76. map<DecisionNode *, pair<long, int> > index;
  77. long maxindex = 0;
  78. node->indexDescendants ( index, maxindex, 0 );
  79. for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
  80. i != index.end();
  81. i++ )
  82. {
  83. DecisionNode *node = i->first;
  84. node->distribution[newClassNo] = node->counter;
  85. }
  86. }
  87. void FPCRandomForestTransfer::partialMLEstimate ( DecisionTree & tree,
  88. Examples & examples_new,
  89. int newClassNo,
  90. int mldepth )
  91. {
  92. map<DecisionNode *, pair<long, int> > index;
  93. long maxindex = 0;
  94. tree.indexDescendants ( index, maxindex );
  95. for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
  96. i != index.end();
  97. i++ )
  98. {
  99. DecisionNode *node = i->first;
  100. pair<long, int> & data = i->second;
  101. int depth = data.second;
  102. if ( depth == mldepth ) {
  103. // I do not care whether this is a leaf node or not
  104. Examples examples_new_rw;
  105. examples_new_rw.insert ( examples_new_rw.begin(),
  106. examples_new.begin(),
  107. examples_new.end() );
  108. // reweight examples
  109. double weight = ( node->distribution.get ( newClassNo ) );
  110. if ( fabs(weight) < 10e-10 )
  111. {
  112. continue;
  113. }
  114. for ( Examples::iterator j = examples_new_rw.begin();
  115. j != examples_new_rw.end() ;
  116. j++ )
  117. {
  118. j->second.weight = weight / examples_new_rw.size();
  119. }
  120. mlEstimate ( node, examples_new_rw, newClassNo );
  121. }
  122. }
  123. }
/** Extends selected leaf nodes of a MAP-estimated tree with new subtrees.
 *
 *  Phase 1: every transfer example belonging to the new class or one of the
 *  mu (support) classes is routed to its leaf; a leaf is marked for extension
 *  unless (a) its new-class mass is numerically negligible, (b) it is not
 *  "critical" (its argmax class is not a mu class) while
 *  extend_only_critical_leafs is set, or (c) its normalized entropy falls
 *  below entropy_rejection_threshold.
 *
 *  Phase 2: for each marked leaf, a fresh subtree is built from the collected
 *  support examples plus all new-class examples (reweighted by the leaf's
 *  per-class mass), spliced in place of the leaf, and the class distributions
 *  of the new subtree are rescaled so that non-support classes keep the old
 *  leaf's values while support-class mass is redistributed proportionally.
 *
 *  @param fp                feature pool used by the subtree builder
 *  @param tree              tree whose leafs are extended (modified in place)
 *  @param examples_transfer examples of the old/support classes
 *  @param examples_new      examples of the new class
 *  @param newClassNo        class number of the new class
 *  @param muClasses         class numbers of the support (mu) classes
 */
void FPCRandomForestTransfer::extendMapTree ( FeaturePool & fp,
    DecisionTree & tree,
    Examples & examples_transfer,
    Examples & examples_new,
    int newClassNo,
    const set<int> & muClasses )
{
  // leaf -> indices (into examples_transfer) of the examples that reached it
  map<DecisionNode *, set<int> > examplesTransferLeafs;

  fprintf (stderr, "FPCRandomForestTransfer: classify all %ld transfer examples\n",
           examples_transfer.size());

  // ---- Phase 1: route transfer examples to leafs and select leafs ----
  int index = 0;
  for ( Examples::iterator i = examples_transfer.begin() ;
        i != examples_transfer.end();
        i++, index++ )
  {
    Example & pce = i->second;
    int example_classno = i->first;
    // only new-class and mu-class examples are relevant for the extension
    if ( (example_classno != newClassNo) &&
         (muClasses.find(example_classno) == muClasses.end() ) )
      continue;
    else
      fprintf (stderr, "suitable example of class %d found !\n", example_classno);

    DecisionNode *leaf = tree.getLeafNode ( pce );

    // skip leafs with (almost) no probability mass for the new class;
    // note the threshold here (10e-2) is much coarser than elsewhere
    double weight = ( leaf->distribution.get ( newClassNo ) );
    if ( fabs(weight) < 10e-2 )
      continue;

    if ( extend_only_critical_leafs )
    {
      // "critical" = the leaf is currently owned by a support class
      int maxClass = leaf->distribution.maxElement();
      if ( muClasses.find(maxClass) == muClasses.end() )
        continue;
    }

    // entropy normalized to [0,1] by dividing by log(#classes)
    double avgentropy = leaf->distribution.entropy() / log(leaf->distribution.size());

    // first time we see this leaf (debug output disabled)
    if ( examplesTransferLeafs.find(leaf) == examplesTransferLeafs.end() )
    {
      /*fprintf (stderr, "FPCRandomForestTransfer: leaf owned by %d (normalized entropy %f)\n", maxClass, avgentropy );
      leaf->distribution.store(cerr); */
    }

    // reject nearly-pure leafs: below the threshold there is nothing to refine
    if ( avgentropy < entropy_rejection_threshold )
    {
      fprintf (stderr, "FPCRandomForestTransfer: leaf rejected due to entropy %f < %f!\n", avgentropy, entropy_rejection_threshold);
      continue;
    }

    examplesTransferLeafs[leaf].insert ( index );
  }

  fprintf (stderr, "FPCRandomForestTransfer: %ld leaf nodes will be extended\n",
           examplesTransferLeafs.size() );

  // ---- Phase 2: build and splice a subtree for each selected leaf ----
  fprintf (stderr, "FPCRandomForestTransfer: Extending Leaf Nodes !\n");
  for ( map<DecisionNode *, set<int> >::iterator k = examplesTransferLeafs.begin();
        k != examplesTransferLeafs.end();
        k++ )
  {
    DecisionNode *node = k->first;
    FullVector examples_counts ( maxClassNo+1 );
    Examples examples_node;
    set<int> & examplesset = k->second;

    // gather the support-class examples that reached this leaf,
    // dropping classes with (numerically) zero mass at the leaf
    for ( set<int>::iterator i = examplesset.begin(); i != examplesset.end(); i++ )
    {
      pair<int, Example> & example = examples_transfer[ *i ];
      if ( node->distribution [ example.first ] < 10e-11 )
        continue;
      examples_node.push_back ( example );
      examples_counts[ example.first ]++;
    }

    fprintf (stderr, "FPCRandomForestTransfer: Examples from support classes %ld\n", examples_node.size() );
    fprintf (stderr, "FPCRandomForestTransfer: Examples from new class %ld (classno %d)\n", examples_new.size(),
             newClassNo);

    // add all new-class examples to the training set of the subtree
    examples_node.insert ( examples_node.begin(), examples_new.begin(), examples_new.end() );
    examples_counts[newClassNo] = examples_new.size();

    fprintf (stderr, "FPCRandomForestTransfer: Extending leaf node with %ld examples\n", examples_node.size() );

    // per class: total weight equals the leaf's probability for that class
    for ( Examples::iterator j = examples_node.begin();
          j != examples_node.end() ;
          j++ )
    {
      int classno = j->first;
      double weight = ( node->distribution.get ( classno ) );
      fprintf (stderr, "examples_counts[%d] = %f; weight %f\n", classno, examples_counts[classno], weight );
      j->second.weight = weight / examples_counts[classno];
    }

    // build the new subtree and splice it in place of the leaf,
    // restoring the leaf's original distribution at the splice point
    DecisionNode *newnode = builder_extend->build ( fp, examples_node, maxClassNo );
    FullVector orig_distribution ( node->distribution );
    node->copy ( newnode );
    node->distribution = orig_distribution;

    orig_distribution.normalize();
    orig_distribution.store(cerr);

    // total mass of the support classes (new class + mu classes) at the root
    // of the spliced subtree; used as the rescaling target below
    double support_node_sum = 0.0;
    for ( int classi = 0 ; classi < node->distribution.size() ; classi++ )
      if ( (classi == newClassNo) || (muClasses.find(classi) != muClasses.end() ) )
        support_node_sum += node->distribution[classi];

    // set all probabilities for non support classes
    // (breadth-first walk over the freshly built subtree)
    std::list<DecisionNode *> stack;
    stack.push_back ( node );
    while ( stack.size() > 0 )
    {
      DecisionNode *cnode = stack.front();
      stack.pop_front();

      // non-support classes inherit the old leaf's values verbatim;
      // support-class mass is summed for rescaling
      double cnode_sum = 0.0;
      for ( int classi = 0 ; classi < cnode->distribution.size() ; classi++ )
        if ( (classi != newClassNo) && (muClasses.find(classi) == muClasses.end() ) )
          cnode->distribution[classi] = node->distribution[classi];
        else
          cnode_sum += cnode->distribution[classi];

      // rescale support classes so their total matches the splice point
      if ( fabs(cnode_sum) > 10e-11 )
        for ( int classi = 0 ; classi < node->distribution.size() ; classi++ )
          if ( (classi == newClassNo) || (muClasses.find(classi) != muClasses.end() ) )
            cnode->distribution[classi] *= support_node_sum / cnode_sum;

      // debug output: normalized distribution of each new leaf
      if ( (cnode->left == NULL) && (cnode->right == NULL ) )
      {
        FullVector stuff ( cnode->distribution );
        stuff.normalize();
        stuff.store(cerr);
      }

      if ( cnode->left != NULL )
        stack.push_back ( cnode->left );
      if ( cnode->right != NULL )
        stack.push_back ( cnode->right );
    }
  }

  fprintf (stderr, "FPCRandomForestTransfer: MAP tree extension done !\n");
}
  244. void FPCRandomForestTransfer::train ( FeaturePool & fp,
  245. Examples & examples )
  246. {
  247. maxClassNo = examples.getMaxClassNo();
  248. fprintf (stderr, "############### FPCRandomForestTransfer::train ####################\n");
  249. assert ( newClass.size() == 1 );
  250. int newClassNo = *(newClass.begin());
  251. // reduce training set
  252. Examples examples_new;
  253. Examples examples_transfer;
  254. for ( Examples::const_iterator i = examples.begin();
  255. i != examples.end();
  256. i++ )
  257. {
  258. int classno = i->first;
  259. if ( newClass.find(classno) != newClass.end() ) {
  260. examples_new.push_back ( *i );
  261. } else {
  262. examples_transfer.push_back ( *i );
  263. }
  264. }
  265. if ( examples_new.size() <= 0 )
  266. {
  267. if ( newClass.size() <= 0 ) {
  268. fprintf (stderr, "FPCRandomForestTransfer::train: no new classes given !\n");
  269. } else {
  270. fprintf (stderr, "FPCRandomForestTransfer::train: no examples found of class %d\n", newClassNo );
  271. }
  272. exit(-1);
  273. }
  274. if ( reduce_training_set )
  275. {
  276. // reduce training set
  277. random_shuffle ( examples_new.begin(), examples_new.end() );
  278. int oldsize = (int)examples_new.size();
  279. int newsize;
  280. if ( training_absolute < 0 )
  281. newsize = (int)(training_ratio*examples_new.size());
  282. else
  283. newsize = training_absolute;
  284. Examples::iterator j = examples_new.begin() + newsize;
  285. examples_new.erase ( j, examples_new.end() );
  286. fprintf (stderr, "Size of training set randomly reduced from %d to %d\n", oldsize,
  287. (int)examples_new.size() );
  288. }
  289. if ( read_cached_prior_structure )
  290. {
  291. FPCRandomForests::read ( cached_prior_structure );
  292. } else {
  293. if ( learn_ert_with_newclass )
  294. {
  295. FPCRandomForests::train ( fp, examples );
  296. } else {
  297. FPCRandomForests::train ( fp, examples_transfer );
  298. }
  299. FPCRandomForests::save ( cached_prior_structure );
  300. }
  301. fprintf (stderr, "MAP ESTIMATION sigmaq = %e\n", sigmaq);
  302. for ( vector<DecisionTree *>::iterator i = forest.begin();
  303. i != forest.end();
  304. i++ )
  305. {
  306. DecisionTree & tree = *(*i);
  307. dte.reestimate ( tree,
  308. examples_new,
  309. sigmaq,
  310. newClassNo,
  311. muClasses,
  312. substituteClasses,
  313. maxClassNo);
  314. if ( partial_ml_estimation )
  315. {
  316. partialMLEstimate ( tree,
  317. examples_new,
  318. newClassNo,
  319. partial_ml_estimation_depth );
  320. }
  321. if ( extend_map_tree )
  322. {
  323. fp.initRandomFeatureSelection ();
  324. extendMapTree ( fp,
  325. tree,
  326. examples_transfer,
  327. examples_new,
  328. newClassNo,
  329. muClasses);
  330. }
  331. }
  332. save ( "map.tree" );
  333. }
  334. FeaturePoolClassifier *FPCRandomForestTransfer::clone () const
  335. {
  336. fprintf (stderr, "FPCRandomForestTransfer::clone() not yet implemented !\n");
  337. exit(-1);
  338. }