DTEstimateAPriori.cpp 13 KB

/**
 * @file DTEstimateAPriori.cpp
 * @brief estimate decision structure using a priori density
 * @author Erik Rodner
 * @date 05/27/2008
 */
#include <iostream>

#include "vislearning/classifier/fpclassifier/randomforest/DTEstimateAPriori.h"
#include "vislearning/optimization/mapestimation/MAPMultinomialGaussianPrior.h"
#include "vislearning/optimization/mapestimation/MAPMultinomialDirichlet.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;

#define DEBUG_DTESTIMATE
DTEstimateAPriori::DTEstimateAPriori( const Config *conf, const std::string & section )
{
  std::string mapEstimatorType_s = conf->gS(section, "map_multinomial_estimator",
    "gaussianprior" );
  if ( mapEstimatorType_s == "gaussianprior" )
    map_estimator = new MAPMultinomialGaussianPrior();
  else if ( mapEstimatorType_s == "dirichletprior" )
    map_estimator = new MAPMultinomialDirichlet();
  else {
    fprintf (stderr, "DTEstimateAPriori: estimator type %s unknown\n", mapEstimatorType_s.c_str() );
    exit(-1);
  }
}
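
/* A minimal config sketch for the constructor above (illustrative only; the
   section name and the INI-style syntax of NICE::Config are assumptions):

     [DTEstimateAPriori]
     map_multinomial_estimator = gaussianprior

   The only other value recognized above is "dirichletprior". */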
DTEstimateAPriori::~DTEstimateAPriori()
{
  delete map_estimator;
}
void DTEstimateAPriori::reestimate ( DecisionTree & dt,
                                     Examples & examples,
                                     double sigmaq,
                                     int newClassNo,
                                     set<int> muClasses,
                                     set<int> substituteClasses,
                                     int maxClassNo )
{
  mapEstimateClass ( dt,
                     examples,
                     newClassNo,
                     muClasses,
                     substituteClasses,
                     sigmaq,
                     maxClassNo );
}
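
/* Usage sketch (illustrative only, not part of the original file): assumes a
   NICE::Config instance `conf`, a trained DecisionTree `dt`, and Examples
   `newExamples` containing instances of the new class; all class numbers and
   the sigma^2 value below are placeholder assumptions.

     DTEstimateAPriori estimator ( &conf, "DTEstimateAPriori" );
     std::set<int> muClasses;          // classes whose leaf distributions act as prior ("mu" classes)
     muClasses.insert ( 0 );
     muClasses.insert ( 1 );
     std::set<int> substituteClasses;  // left empty: no existing class is replaced
     estimator.reestimate ( dt, newExamples, 0.1, 17, muClasses, substituteClasses, 17 );
*/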
/** calculates node probabilities recursively
    using the following formula:
    p(n | i) = p(parent(n) | i) * ( c(i | n) / c(i | parent(n)) )
    where c(i | n) is the (unnormalized) count of class i in node n
    @remark do not use normalized a-posteriori values!
*/
void DTEstimateAPriori::calculateNodeProbabilitiesRec (
  map<DecisionNode *, FullVector> & p,
  DecisionNode *node )
{
  if ( node == NULL ) return;
  else if ( (node->left == NULL) && (node->right == NULL) ) return;
  else {
    assert ( node->left != NULL );
    assert ( node->right != NULL );
    // estimate probabilities for both children
    const FullVector & parent = p[node];
    // calculate left prob
    const FullVector & posteriori = node->distribution;
    const FullVector & posteriori_left = node->left->distribution;
    const FullVector & posteriori_right = node->right->distribution;
    FullVector result_left (parent);
    FullVector result_right (parent);
    FullVector transition_left ( posteriori_left );
    transition_left.divide ( posteriori );
    assert ( transition_left.max() <= 1.0 );
    assert ( transition_left.min() >= 0.0 );
    FullVector transition_right ( posteriori_right );
    transition_right.divide ( posteriori );
    result_left.multiply ( transition_left );
    result_right.multiply ( transition_right );
    p.insert ( pair<DecisionNode *, FullVector> ( node->left, result_left ) );
    p.insert ( pair<DecisionNode *, FullVector> ( node->right, result_right ) );
    calculateNodeProbabilitiesRec ( p, node->left );
    calculateNodeProbabilitiesRec ( p, node->right );
  }
}
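
/* Worked example of the recursion above (assumed numbers): if class i reaches
   the parent with probability p(parent | i) = 0.5, the parent stores
   c(i | parent) = 40 examples of class i, and its left child stores
   c(i | left) = 30, then p(left | i) = 0.5 * (30 / 40) = 0.375. */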
void DTEstimateAPriori::calculateNodeProbabilities ( map<DecisionNode *, FullVector> & p, DecisionTree & tree )
{
  DecisionNode *root = tree.getRoot();
  FullVector rootNP ( root->distribution.size() );
  // root node probability is 1 for each class
  rootNP.set ( 1.0 );
  p.insert ( pair<DecisionNode *, FullVector> ( root, rootNP ) );
  calculateNodeProbabilitiesRec ( p, root );
}
void DTEstimateAPriori::calculateNodeProbVec (
  map<DecisionNode *, FullVector> & nodeProbs,
  int classno,
  // refactor-nice.pl: check this substitution
  // old: Vector & p )
  NICE::Vector & p )
{
  double sum = 0.0;
  assert ( p.size() == 0 );
  for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
        k != nodeProbs.end(); k++ )
  {
    const FullVector & v = k->second;
    DecisionNode *node = k->first;
    // skip inner nodes; only leaves contribute
    if ( (node->left != NULL) || (node->right != NULL) )
      continue;
    double val = v[classno];
    NICE::Vector single (1);
    single[0] = val;
    // inefficient: the vector is grown by one element per append
    p.append ( single );
    sum += val;
  }
  for ( size_t i = 0 ; i < p.size() ; i++ )
    p[i] /= sum;
}
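
/* On return, p holds one entry per leaf n with value p(n | classno),
   normalized so that the entries sum to one over all leaves. */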
double DTEstimateAPriori::calcInnerNodeProbs (
  DecisionNode *node,
  map<DecisionNode *, double> & p )
{
  map<DecisionNode *, double>::const_iterator i = p.find( node );
  if ( i == p.end() )
  {
    double prob = 0.0;
    if ( node->left != NULL )
      prob += calcInnerNodeProbs ( node->left, p );
    if ( node->right != NULL )
      prob += calcInnerNodeProbs ( node->right, p );
    p.insert ( pair<DecisionNode *, double> ( node, prob ) );
    return prob;
  } else {
    return i->second;
  }
}
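
/* Example (assumed numbers): an inner node whose two children are leaves with
   node probabilities 0.2 and 0.3 receives 0.2 + 0.3 = 0.5; inner node
   probabilities are accumulated bottom-up from the leaf values stored in p. */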
/** calculates a-posteriori probabilities using the formula:
    p(i | n) = p(n | i) p(i) ( \sum_j p(n | j) p(j) )^{-1}
*/
void DTEstimateAPriori::calcPosteriori (
  DecisionNode *node,
  const FullVector & apriori,
  const map<DecisionNode *, FullVector> & nodeprob,
  map<DecisionNode *, FullVector> & posterioriResult )
{
  if ( node == NULL ) return;
  map<DecisionNode *, FullVector>::const_iterator i;
  i = nodeprob.find ( node );
  assert ( i != nodeprob.end() );
  const FullVector & np = i->second;
  assert ( np.sum() > 10e-7 );
  FullVector joint ( np );
  joint.multiply ( apriori );
  joint.normalize();
  posterioriResult.insert ( pair<DecisionNode *, FullVector> ( node, joint ) );
  calcPosteriori (node->left, apriori, nodeprob, posterioriResult);
  calcPosteriori (node->right, apriori, nodeprob, posterioriResult);
}
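
/* Worked Bayes example (assumed numbers): with node probabilities
   p(n | .) = (0.2, 0.6) and a priori p = (0.5, 0.5), the joint is
   (0.1, 0.3), which normalizes to the posteriori (0.25, 0.75). */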
/** calculates a-posteriori probabilities by substituting support class
    values with new ones
*/
void DTEstimateAPriori::calcPosteriori ( DecisionTree & tree,
                                         const FullVector & aprioriOld,
                                         const map<DecisionNode *, FullVector> & nodeprobOld,
                                         map<DecisionNode *, double> & nodeprobNew,
                                         const set<int> & substituteClasses,
                                         int newClassNo,
                                         map<DecisionNode *, FullVector> & posterioriResult )
{
  // calculating node probabilities of inner nodes
  calcInnerNodeProbs ( tree.getRoot(), nodeprobNew );

  // building new a priori probabilities
  FullVector apriori ( aprioriOld );
  for ( int i = 0 ; i < apriori.size() ; i++ )
    if ( substituteClasses.find( i ) != substituteClasses.end() )
      apriori[i] = 0.0;
  if ( substituteClasses.size() > 0 )
    apriori[newClassNo] = 1.0 - apriori.sum();
  else {
    // mean a priori
    double avg = apriori.sum() / apriori.size();
    apriori[newClassNo] = avg;
    apriori.normalize();
  }
  if ( substituteClasses.size() > 0 )
  {
    fprintf (stderr, "WARNING: do you really want to do class substitution?\n");
  }

  // building new node probabilities
  map<DecisionNode *, FullVector> nodeprob;
  for ( map<DecisionNode *, FullVector>::const_iterator j = nodeprobOld.begin();
        j != nodeprobOld.end();
        j++ )
  {
    const FullVector & d = j->second;
    DecisionNode *node = j->first;
    map<DecisionNode *, double>::const_iterator k = nodeprobNew.find ( node );
    assert ( k != nodeprobNew.end() );
    double newNP = k->second;
    assert ( d.sum() > 10e-7 );
    FullVector np ( d );
    for ( int i = 0 ; i < d.size() ; i++ )
      if ( substituteClasses.find( i ) != substituteClasses.end() )
        np[i] = 0.0;
    if ( (np[ newClassNo ] > 10e-7) && (newNP < 10e-7) ) {
      fprintf (stderr, "DTEstimateAPriori: handling special case!\n");
    } else {
      np[ newClassNo ] = newNP;
    }
    if ( np.sum() < 10e-7 )
    {
      fprintf (stderr, "DTEstimateAPriori: handling special case (2), mostly for binary tasks!\n");
      assert ( substituteClasses.size() == 1 );
      int oldClassNo = *(substituteClasses.begin());
      np[ newClassNo ] = d[ oldClassNo ];
    }
    nodeprob.insert ( pair<DecisionNode *, FullVector> ( node, np ) );
  }
  calcPosteriori ( tree.getRoot(), apriori, nodeprob, posterioriResult );
}
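
/* Example of the a priori substitution above (assumed numbers): for four
   classes with aprioriOld = (0.5, 0.3, 0.2, 0.0), substituteClasses = {2}
   and newClassNo = 3, class 2 is zeroed out and the new class receives the
   remaining mass: apriori = (0.5, 0.3, 0.0, 0.2). With an empty
   substituteClasses set, the new class instead gets the mean value and the
   vector is renormalized. */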
void DTEstimateAPriori::mapEstimateClass ( DecisionTree & tree,
                                           Examples & new_examples,
                                           int newClassNo,
                                           set<int> muClasses,
                                           set<int> substituteClasses,
                                           double sigmaq,
                                           int maxClassNo )
{
  // ----------- (0) class a priori information ---------------------------------------------
  FullVector & root_distribution = tree.getRoot()->distribution;
  FullVector apriori ( root_distribution );
  apriori.normalize();

  // ----------- (1) collect leaf probabilities of oldClassNo -> mu -------------------------
  fprintf (stderr, "DTEstimateAPriori: calculating mu vector\n");
  map<DecisionNode *, FullVector> nodeProbs;
  calculateNodeProbabilities ( nodeProbs, tree );
  VVector priorDistributionSamples;
  for ( set<int>::const_iterator i = muClasses.begin();
        i != muClasses.end();
        i++ )
  {
    NICE::Vector p;
    calculateNodeProbVec ( nodeProbs, *i, p );
    priorDistributionSamples.push_back(p);
  }

  // ----------- (2) infer examples_new into tree -> leaf prob counts -----------------------
  FullVector distribution ( maxClassNo+1 );
  fprintf (stderr, "DTEstimateAPriori: Inferring %d new examples into the tree\n", (int)new_examples.size() );
  assert ( new_examples.size() > 0 );
  tree.resetCounters ();
  for ( Examples::iterator j = new_examples.begin() ;
        j != new_examples.end() ;
        j++ )
    tree.traverse ( j->second, distribution );
  // refactor-nice.pl: check this substitution
  // old: Vector scores;
  vector<double> scores_stl;
  for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
        k != nodeProbs.end(); k++ )
  {
    DecisionNode *node = k->first;
    if ( (node->left != NULL) || (node->right != NULL) )
      continue;
    scores_stl.push_back ( node->counter );
  }
  NICE::Vector scores (scores_stl);
  VVector likelihoodDistributionSamples;
  likelihoodDistributionSamples.push_back ( scores );

  // ------------------------------- (3) map estimation ------------------------------------------
  fprintf (stderr, "DTEstimateAPriori: MAP estimation ... sigmaq = %e\n", sigmaq);
  NICE::Vector theta;
  // scores = ML solution = counts in each leaf node
  // theta = solution of the MAP estimation
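  // Presumably (not verified against the estimator implementations), theta
  // trades off the ML leaf counts of the new class against the leaf
  // distributions of the mu classes, with sigmaq controlling the prior strength.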
  map_estimator->estimate ( theta, likelihoodDistributionSamples, priorDistributionSamples, sigmaq );
  assert ( theta.size() == scores.size() );

  // compute normalized scores
  NICE::Vector scores_n ( scores );
  double sum = 0.0;
  for ( int k = 0 ; k < (int)scores_n.size() ; k++ )
    sum += scores_n[k];
  assert ( fabs(sum) > 10e-8 );
  for ( int k = 0 ; k < (int)scores_n.size() ; k++ )
    scores_n[k] /= sum;

  // ---------- (4) calculate posteriori probs in each leaf according to leaf probs ---------------------
  map<DecisionNode *, double> npMAP;
  long index = 0;
  for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
        k != nodeProbs.end(); k++ )
  {
    DecisionNode *node = k->first;
    if ( (node->left != NULL) || (node->right != NULL) )
      continue;
    npMAP[node] = theta[index];
    index++;
  }
  map<DecisionNode *, FullVector> posteriori;
  calcPosteriori ( tree, apriori, nodeProbs, npMAP, substituteClasses,
                   newClassNo, posteriori );

  // (5) substitute class scores
  for ( map<DecisionNode *, FullVector>::iterator i = posteriori.begin();
        i != posteriori.end();
        i++ )
  {
    DecisionNode *node = i->first;
    if ( (node->left != NULL) || (node->right != NULL) )
    {
      //fprintf (stderr, "MAPMultinomialGaussianPrior: reestimating prob of an inner node!\n");
      continue;
    }
#ifdef DEBUG_DTESTIMATE
    FullVector old_distribution ( node->distribution );
    old_distribution.normalize();
    old_distribution.store (cerr);
#endif
    for ( int k = 0 ; k < node->distribution.size() ; k++ )
      if ( substituteClasses.find( k ) != substituteClasses.end() )
        node->distribution[k] = 0.0;

    // recalculate probabilities in weights
    double oldvalue = node->distribution.get(newClassNo);
    double supportsum = node->distribution.sum() - oldvalue;
    double pgamma = i->second[newClassNo];
    if ( (fabs(supportsum) > 10e-11) && (fabs(1.0-pgamma) < 10e-11) )
    {
      fprintf (stderr, "DTEstimateAPriori: corrupted probabilities\n");
      fprintf (stderr, "sum of all other classes: %f\n", supportsum );
      fprintf (stderr, "prob of new class: %f\n", pgamma );
      exit(-1);
    }
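    // The new (unnormalized) weight w of newClassNo is chosen such that the
    // normalized distribution assigns probability pgamma to the new class:
    //   w / (w + supportsum) = pgamma  <=>  w = supportsum * pgamma / (1 - pgamma)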
    double newvalue = 0.0;
    if ( fabs(supportsum) < 10e-11 )
      newvalue = 0.0;
    else
      newvalue = supportsum * pgamma / (1.0 - pgamma);
    if ( (muClasses.size() == 1) && (substituteClasses.size() == 0) )
    {
      double muvalue = node->distribution.get( *(muClasses.begin()) );
#ifdef DEBUG_DTESTIMATE
      fprintf (stderr, "#REESTIMATE old=%f new=%f mu=%f pgamma=%f likelihood_prob=%f estimated_prob=%f\n", oldvalue,
               newvalue, muvalue, pgamma, nodeProbs[node][newClassNo], npMAP[node] );
      fprintf (stderr, "#REESTIMATE mu_prob=%f\n", nodeProbs[node][ *(muClasses.begin()) ] );
#endif
    } else {
#ifdef DEBUG_DTESTIMATE
      fprintf (stderr, "#REESTIMATE old=%f new=%f pgamma=%f supportsum=%f\n", oldvalue, newvalue, pgamma, supportsum );
#endif
    }
    //if ( newvalue > oldvalue )
    node->distribution[newClassNo] = newvalue;
#ifdef DEBUG_DTESTIMATE
    FullVector new_distribution ( node->distribution );
    // new_distribution.normalize();
    // new_distribution.store (cerr);
    /*
    for ( int i = 0 ; i < new_distribution.size() ; i++ )
    {
      if ( (muClasses.find(i) != muClasses.end()) || ( i == newClassNo ) )
        continue;
      if ( new_distribution[i] != old_distribution[i] )
      {
        fprintf (stderr, "class %d %f <-> %f\n", i, new_distribution[i], old_distribution[i] );
        new_distribution.store ( cerr );
        old_distribution.store ( cerr );
        node->distribution.store ( cerr );
        exit(-1);
      }
    }
    */
#endif
  }
  int count, depth;
  tree.statistics ( depth, count );
  assert ( count == (int)posteriori.size() );
  tree.pruneTreeScore ( 10e-10 );
}