// DTEstimateAPriori.cpp
  1. /**
  2. * @file DTEstimateAPriori.cpp
  3. * @brief estimate decision structure using a priori density
  4. * @author Erik Rodner
  5. * @date 05/27/2008
  6. */
  7. #ifdef NOVISUAL
  8. #include <vislearning/nice_nonvis.h>
  9. #else
  10. #include <vislearning/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include "vislearning/classifier/fpclassifier/randomforest/DTEstimateAPriori.h"
  14. #include "vislearning/optimization/mapestimation/MAPMultinomialGaussianPrior.h"
  15. #include "vislearning/optimization/mapestimation/MAPMultinomialDirichlet.h"
  16. using namespace OBJREC;
  17. using namespace std;
  18. using namespace NICE;
  19. #define DEBUG_DTESTIMATE
  20. DTEstimateAPriori::DTEstimateAPriori( const Config *conf, const std::string & section )
  21. {
  22. std::string mapEstimatorType_s = conf->gS(section, "map_multinomial_estimator",
  23. "gaussianprior" );
  24. if ( mapEstimatorType_s == "gaussianprior" )
  25. map_estimator = new MAPMultinomialGaussianPrior();
  26. else if ( mapEstimatorType_s == "dirichletprior" )
  27. map_estimator = new MAPMultinomialDirichlet();
  28. else {
  29. fprintf (stderr, "DTEstimateAPriori: estimator type %s unknown\n", mapEstimatorType_s.c_str() );
  30. exit(-1);
  31. }
  32. }
/** Destructor: releases the MAP estimator allocated in the constructor.
 *  map_estimator is always non-NULL here, since the constructor either
 *  assigns it or calls exit().
 */
DTEstimateAPriori::~DTEstimateAPriori()
{
delete map_estimator;
}
  37. void DTEstimateAPriori::reestimate ( DecisionTree & dt,
  38. Examples & examples,
  39. double sigmaq,
  40. int newClassNo,
  41. set<int> muClasses,
  42. set<int> substituteClasses,
  43. int maxClassNo )
  44. {
  45. mapEstimateClass ( dt,
  46. examples,
  47. newClassNo,
  48. muClasses,
  49. substituteClasses,
  50. sigmaq,
  51. maxClassNo );
  52. }
  53. /** calculating node probabilities recursive
  54. using the following formula:
  55. p(n | i) = p(p | i) ( c(i | n) c( i | p)^{-1} )
  56. @remark do not use normalized a posteriori values !
  57. */
  58. void DTEstimateAPriori::calculateNodeProbabilitiesRec (
  59. map<DecisionNode *, FullVector> & p,
  60. DecisionNode *node )
  61. {
  62. if ( node == NULL ) return;
  63. else if ( (node->left == NULL) && (node->right == NULL ) ) return;
  64. else {
  65. assert ( left != NULL );
  66. assert ( right != NULL );
  67. // estimate probabilies for children
  68. const FullVector & parent = p[node];
  69. // calculate left prob
  70. const FullVector & posteriori = node->distribution;
  71. const FullVector & posteriori_left = node->left->distribution;
  72. const FullVector & posteriori_right = node->right->distribution;
  73. FullVector result_left (parent);
  74. FullVector result_right (parent);
  75. FullVector transition_left ( posteriori_left );
  76. transition_left.divide ( posteriori );
  77. assert ( transition_left.max() <= 1.0 );
  78. assert ( transition_left.min() >= 0.0 );
  79. FullVector transition_right ( posteriori_right );
  80. transition_right.divide ( posteriori );
  81. result_left.multiply ( transition_left );
  82. result_right.multiply ( transition_right );
  83. p.insert ( pair<DecisionNode *, FullVector> ( node->left, result_left ) );
  84. p.insert ( pair<DecisionNode *, FullVector> ( node->right, result_right ) );
  85. calculateNodeProbabilitiesRec ( p, node->left );
  86. calculateNodeProbabilitiesRec ( p, node->right );
  87. }
  88. }
  89. void DTEstimateAPriori::calculateNodeProbabilities ( map<DecisionNode *, FullVector> & p, DecisionTree & tree )
  90. {
  91. DecisionNode *root = tree.getRoot();
  92. FullVector rootNP ( root->distribution.size() );
  93. // root node probability is 1 for each class
  94. rootNP.set ( 1.0 );
  95. p.insert ( pair<DecisionNode *, FullVector> ( root, rootNP ) );
  96. calculateNodeProbabilitiesRec ( p, root );
  97. }
  98. void DTEstimateAPriori::calculateNodeProbVec
  99. ( map<DecisionNode *, FullVector> & nodeProbs,
  100. int classno,
  101. // refactor-nice.pl: check this substitution
  102. // old: Vector & p )
  103. NICE::Vector & p )
  104. {
  105. double sum = 0.0;
  106. assert ( p.size() == 0 );
  107. for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
  108. k != nodeProbs.end(); k++ )
  109. {
  110. const FullVector & v = k->second;
  111. DecisionNode *node = k->first;
  112. if ( (node->left != NULL) || (node->right != NULL) )
  113. continue;
  114. double val = v[classno];
  115. NICE::Vector single (1);
  116. single[0] = val;
  117. // inefficient !!
  118. p.append ( single );
  119. sum += val;
  120. }
  121. for ( size_t i = 0 ; i < p.size() ; i++ )
  122. p[i] /= sum;
  123. }
  124. double DTEstimateAPriori::calcInnerNodeProbs (
  125. DecisionNode *node,
  126. map<DecisionNode *, double> & p )
  127. {
  128. map<DecisionNode *, double>::const_iterator i = p.find( node );
  129. if ( i == p.end() )
  130. {
  131. double prob = 0.0;
  132. if ( node->left != NULL )
  133. prob += calcInnerNodeProbs ( node->left, p );
  134. if ( node->right != NULL )
  135. prob += calcInnerNodeProbs ( node->right, p );
  136. p.insert ( pair<DecisionNode *, double> ( node, prob ) );
  137. return prob;
  138. } else {
  139. return i->second;
  140. }
  141. }
  142. /** calculates a-posteriori probabilities using the formula:
  143. p(i | n) = p(n | i) p(i) ( \sum_j p(n | j) p(j) )^{-1}
  144. */
  145. void DTEstimateAPriori::calcPosteriori (
  146. DecisionNode *node,
  147. const FullVector & apriori,
  148. const map<DecisionNode *, FullVector> & nodeprob,
  149. map<DecisionNode *, FullVector> & posterioriResult )
  150. {
  151. if ( node == NULL ) return;
  152. map<DecisionNode *, FullVector>::const_iterator i;
  153. i = nodeprob.find ( node );
  154. assert ( i != nodeprob.end() );
  155. const FullVector & np = i->second;
  156. assert ( np.sum() > 10e-7 );
  157. FullVector joint ( np );
  158. joint.multiply ( apriori );
  159. joint.normalize();
  160. posterioriResult.insert ( pair<DecisionNode *, FullVector> ( node, joint ) );
  161. calcPosteriori (node->left, apriori, nodeprob, posterioriResult);
  162. calcPosteriori (node->right, apriori, nodeprob, posterioriResult);
  163. }
/** calculates a-posteriori probabilities by substituting support class
    values with new ones
    @param tree              tree whose root is used to complete inner node probabilities
    @param aprioriOld        previous class a priori distribution p(i)
    @param nodeprobOld       previous node probabilities p(n|i) for all nodes
    @param nodeprobNew       new (leaf) node probabilities of the new class;
                             inner-node entries are filled in here via calcInnerNodeProbs
    @param substituteClasses classes whose probability mass is zeroed out and
                             transferred to newClassNo (may be empty)
    @param newClassNo        class index receiving the new probability values
    @param posterioriResult  output: p(i|n) for every node of the tree
*/
void DTEstimateAPriori::calcPosteriori ( DecisionTree & tree,
const FullVector & aprioriOld,
const map<DecisionNode *, FullVector> & nodeprobOld,
map<DecisionNode *, double> & nodeprobNew,
const set<int> & substituteClasses,
int newClassNo,
map<DecisionNode *, FullVector> & posterioriResult )
{
// calculating node probabilities of inner nodes
// (leaves are assumed to be present in nodeprobNew already;
//  inner nodes become the sum of their children)
calcInnerNodeProbs ( tree.getRoot(), nodeprobNew );
// building new apriori probabilities:
// substituted classes lose their mass entirely ...
FullVector apriori ( aprioriOld );
for ( int i = 0 ; i < apriori.size() ; i++ )
if ( substituteClasses.find( i ) != substituteClasses.end() )
apriori[i] = 0.0;
// ... and the new class either absorbs the removed mass (substitution case)
// or gets the average prior of all classes followed by renormalization
if ( substituteClasses.size() > 0 )
apriori[newClassNo] = 1.0 - apriori.sum();
else {
// mean a priori
double avg = apriori.sum() / apriori.size();
apriori[newClassNo] = avg;
apriori.normalize();
}
if ( substituteClasses.size() > 0 )
{
fprintf (stderr, "WARNING: do you really want to do class substitution ?\n");
}
// building new node probabilities: start from the old p(n|i), zero out
// substituted classes, and inject the new value for newClassNo
map<DecisionNode *, FullVector> nodeprob;
for ( map<DecisionNode *, FullVector>::const_iterator j = nodeprobOld.begin();
j != nodeprobOld.end();
j++ )
{
const FullVector & d = j->second;
DecisionNode *node = j->first;
map<DecisionNode *, double>::const_iterator k = nodeprobNew.find ( node );
assert ( k != nodeprobNew.end() );
double newNP = k->second;
assert ( d.sum() > 10e-7 );
FullVector np ( d );
for ( int i = 0 ; i < d.size() ; i++ )
if ( substituteClasses.find( i ) != substituteClasses.end() )
np[i] = 0.0;
// special case: the old probability was non-zero but the new estimate
// is (numerically) zero -> keep the old value instead of overwriting
if ( (np[ newClassNo ] > 10e-7) && (newNP < 10e-7) ) {
fprintf (stderr, "DTEstimateAPriori: handling special case!\n");
} else {
np[ newClassNo ] = newNP;
}
// special case: all entries vanished (can happen for binary tasks when
// the single substituted class carried all the mass) -> reuse the old
// probability of the substituted class for the new class
if ( np.sum() < 10e-7 )
{
fprintf (stderr, "DTEstimateAPriori: handling special case (2), mostly for binary tasks!\n");
assert ( substituteClasses.size() == 1 );
int oldClassNo = *(substituteClasses.begin());
np[ newClassNo ] = d[ oldClassNo ];
}
nodeprob.insert ( pair<DecisionNode *, FullVector> ( node, np ) );
}
// finally convert the adjusted node probabilities into a-posteriori values
calcPosteriori ( tree.getRoot(), apriori, nodeprob, posterioriResult );
}
/** Re-estimates the leaf distributions of \c tree for a new class using
 *  MAP estimation with an a priori density derived from the \c muClasses.
 *  Pipeline: (0) class priors from the root distribution, (1) leaf
 *  probability vectors of the mu classes as prior samples, (2) counts of
 *  the new examples in each leaf as likelihood sample, (3) MAP estimation,
 *  (4) posteriori computation, (5) substitution of the class scores in the
 *  leaf distributions, followed by pruning.
 *  @param tree              decision tree to modify in place
 *  @param new_examples      examples of the new class, traversed through the tree
 *  @param newClassNo        class index whose scores are re-estimated
 *  @param muClasses         classes providing the prior distribution samples
 *  @param substituteClasses classes to be replaced by the new class (usually empty)
 *  @param sigmaq            variance parameter passed to the MAP estimator
 *  @param maxClassNo        maximum class number (sizes the traversal distribution)
 */
void DTEstimateAPriori::mapEstimateClass ( DecisionTree & tree,
Examples & new_examples,
int newClassNo,
set<int> muClasses,
set<int> substituteClasses,
double sigmaq,
int maxClassNo )
{
// ----------- (0) class a priori information ---------------------------------------------
// normalized root distribution = empirical class priors
FullVector & root_distribution = tree.getRoot()->distribution;
FullVector apriori ( root_distribution );
apriori.normalize();
// ----------- (1) collect leaf probabilities of oldClassNo -> mu -------------------------
fprintf (stderr, "DTEstimateAPriori: calculating mu vector\n");
map<DecisionNode *, FullVector> nodeProbs;
calculateNodeProbabilities ( nodeProbs, tree );
// one normalized leaf-probability vector per mu class serves as a prior sample
VVector priorDistributionSamples;
for ( set<int>::const_iterator i = muClasses.begin();
i != muClasses.end();
i++ )
{
NICE::Vector p;
calculateNodeProbVec ( nodeProbs, *i, p );
priorDistributionSamples.push_back(p);
}
// ----------- (2) infer examples_new into tree -> leaf prob counts -----------------------
FullVector distribution ( maxClassNo+1 );
fprintf (stderr, "DTEstimateAPriori: Infering %d new examples into the tree\n", (int)new_examples.size() );
assert ( new_examples.size() > 0 );
// traversal increments per-node counters, which are read below
tree.resetCounters ();
for ( Examples::iterator j = new_examples.begin() ;
j != new_examples.end() ;
j++ )
tree.traverse ( j->second, distribution );
// refactor-nice.pl: check this substitution
// old: Vector scores;
// leaf counters in nodeProbs iteration order form the ML solution
vector<double> scores_stl;
for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
k != nodeProbs.end(); k++ )
{
DecisionNode *node = k->first;
if ( (node->left != NULL) || (node->right != NULL) )
continue;
scores_stl.push_back ( node->counter );
}
NICE::Vector scores (scores_stl);
VVector likelihoodDistributionSamples;
likelihoodDistributionSamples.push_back ( scores );
// ------------------------------- (3) map estimation ------------------------------------------
fprintf (stderr, "DTEstimateAPriori: MAP estimation ...sigmaq = %e\n", sigmaq);
NICE::Vector theta;
// scores = ML solution = counts in each leaf node
// theta = solution of the MAP estimation
map_estimator->estimate ( theta, likelihoodDistributionSamples, priorDistributionSamples, sigmaq );
assert ( theta.size() == scores.size() );
// compute normalized scores
NICE::Vector scores_n ( scores );
double sum = 0.0;
for ( int k = 0 ; k < (int)scores_n.size() ; k++ )
sum += scores_n[k];
assert ( fabs(sum) > 10e-8 );
for ( int k = 0 ; k < (int)scores_n.size() ; k++ )
scores_n[k] /= sum;
// ---------- (4) calculate posteriori probs in each leaf according to leaf probs ---------------------
// theta entries correspond to the leaves in the same iteration order as
// used when building scores_stl above
map<DecisionNode *, double> npMAP;
long index = 0;
for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
k != nodeProbs.end(); k++ )
{
DecisionNode *node = k->first;
if ( (node->left != NULL) || (node->right != NULL) )
continue;
npMAP[node] = theta[index];
index++;
}
map<DecisionNode *, FullVector> posteriori;
calcPosteriori ( tree, apriori, nodeProbs, npMAP, substituteClasses,
newClassNo, posteriori );
// (5) substitute class scores
for ( map<DecisionNode *, FullVector>::iterator i = posteriori.begin();
i != posteriori.end();
i++ )
{
DecisionNode *node = i->first;
// only leaf distributions are rewritten
if ( (node->left != NULL) || (node->right != NULL) )
{
//fprintf (stderr, "MAPMultinomialGaussianPrior: reestimating prob of a inner node !\n");
continue;
}
#ifdef DEBUG_DTESTIMATE
FullVector old_distribution ( node->distribution );
old_distribution.normalize();
old_distribution.store (cerr);
#endif
// zero out the scores of substituted classes before rescaling
for ( int k = 0 ; k < node->distribution.size() ; k++ )
if ( substituteClasses.find( k ) != substituteClasses.end() )
node->distribution[k] = 0.0;
// recalculate probabilities in weights:
// choose newvalue such that newvalue / (supportsum + newvalue) = pgamma
double oldvalue = node->distribution.get(newClassNo);
double supportsum = node->distribution.sum() - oldvalue;
double pgamma = i->second[newClassNo];
// pgamma == 1 with non-empty support would require an infinite score
if ( (fabs(supportsum) > 10e-11) && (fabs(1.0-pgamma) < 10e-11 ) )
{
fprintf (stderr, "DTEstimateAPriori: corrupted probabilities\n");
fprintf (stderr, "sum of all other class: %f\n", supportsum );
fprintf (stderr, "prob of new class: %f\n", pgamma );
exit(-1);
}
double newvalue = 0.0;
if ( fabs(supportsum) < 10e-11 )
newvalue = 0.0;
else
newvalue = supportsum * pgamma / (1.0 - pgamma);
if ( (muClasses.size() == 1) && (substituteClasses.size() == 0) )
{
// muvalue is only used by the debug output below
double muvalue = node->distribution.get( *(muClasses.begin()) );
#ifdef DEBUG_DTESTIMATE
fprintf (stderr, "#REESTIMATE old=%f new=%f mu=%f pgamma=%f likelihood_prob=%f estimated_prob=%f\n", oldvalue,
newvalue, muvalue, pgamma, nodeProbs[node][newClassNo], npMAP[node] );
fprintf (stderr, "#REESTIMATE mu_prob=%f\n", nodeProbs[node][ *(muClasses.begin()) ] );
#endif
} else {
#ifdef DEBUG_DTESTIMATE
fprintf (stderr, "#REESTIMATE old=%f new=%f pgamma=%f supportsum=%f\n", oldvalue, newvalue, pgamma, supportsum );
#endif
}
//if ( newvalue > oldvalue )
node->distribution[newClassNo] = newvalue;
#ifdef DEBUG_DTESTIMATE
FullVector new_distribution ( node->distribution );
// new_distribution.normalize();
// new_distribution.store (cerr);
/*
for ( int i = 0 ; i < new_distribution.size() ; i++ )
{
if ( (muClasses.find(i) != muClasses.end()) || ( i == newClassNo ) )
continue;
if ( new_distribution[i] != old_distribution[i] )
{
fprintf (stderr, "class %d %f <-> %f\n", i, new_distribution[i], old_distribution[i] );
new_distribution.store ( cerr );
old_distribution.store ( cerr );
node->distribution.store ( cerr );
exit(-1);
}
}
*/
#endif
}
// sanity check: we touched every node exactly once, then prune tiny leaves
int count, depth;
tree.statistics ( depth, count );
assert ( count == (int)posteriori.size() );
tree.pruneTreeScore ( 10e-10 );
}