/**
 * @file RTBMeanPostImprovement.cpp
 * @brief random regression tree based on the mean posterior improvement (MPI) criterion
 * @author Sven Sickert
 * @date 07/23/2013
 */
#define _USE_MATH_DEFINES
#include <iostream>
#include <fstream>   // ofstream for the DETAILTREE output
#include <algorithm> // min_element / max_element
#include <ctime>     // time() for seeding the random generator
#include <math.h>

#include "RTBMeanPostImprovement.h"

using namespace OBJREC;

#undef DEBUGTREE
#undef DETAILTREE

using namespace std;
using namespace NICE;
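/** Read the tree builder settings from the given config section:
 *  - random_split_tests:   number of random thresholds tested per feature
 *  - random_features:      number of random feature dimensions tested per node
 *  - max_depth:            maximum depth of the tree
 *  - min_examples:         minimum number of examples per node
 *  - minimum_improvement:  minimal MPI value required to accept a split
 *                          (note: read from the "RandomForest" section)
 *  - save_indices:         whether to store training indices in each node
 *  - auto_bandwith:        automatic kernel bandwidth selection
 *                          (key spelled as the member variable in the header)
 */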
RTBMeanPostImprovement::RTBMeanPostImprovement( const Config *conf, std::string section )
{
  random_split_tests = conf->gI(section, "random_split_tests", 10 );
  random_features = conf->gI(section, "random_features", 500 );
  max_depth = conf->gI(section, "max_depth", 10 );
  min_examples = conf->gI(section, "min_examples", 50);
  minimum_improvement = conf->gD("RandomForest", "minimum_improvement", 10e-3 );
  save_indices = conf->gB(section, "save_indices", false);
  auto_bandwith = conf->gB(section, "auto_bandwith", true);

  if ( conf->gB(section, "start_random_generator", false ) )
    srand(time(NULL));
}
RTBMeanPostImprovement::~RTBMeanPostImprovement()
{
}
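/** Split the examples in "values" at "threshold" and prepare the quantities
 *  needed for the mean posterior improvement criterion:
 *  - empDist_left/right: empirical distribution function values of the
 *    targets y on each side [Taylor & Jones, 1996]
 *  - count_left/right:   number of examples on each side
 *  - h:                  kernel bandwidth
 *  - p:                  fraction of examples falling to the left side
 *  Returns false (no split) if either side has fewer than min_examples.
 */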
bool RTBMeanPostImprovement::improvementLeftRight(const vector< pair< double, int > > values,
      const Vector & y,
      double threshold,
      vector<double> & empDist_left,
      vector<double> & empDist_right,
      int& count_left,
      int& count_right,
      double& h,
      double& p )
{
  count_left = 0;
  count_right = 0;
  vector<double> selection_left;
  vector<double> selection_right;

  for ( vector< pair< double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    if ( (it->first) < threshold )
    {
      count_left++;
      selection_left.push_back( y[ it->second ] );
    }
    else
    {
      count_right++;
      selection_right.push_back( y[ it->second ] );
    }
  }

  if ( (count_left < min_examples) || (count_right < min_examples) )
    return false; // no split

  Vector vleft ( selection_left );
  Vector vright ( selection_right );

  // empirical distribution [Taylor & Jones, 1996]
  for ( vector< pair< double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    double yval = y[ it->second ];
    int smaller_left = 0;
    int smaller_right = 0;
    for ( int l = 0; l < count_left; l++ )
    {
      if ( selection_left[l] <= yval ) smaller_left++;
    }
    for ( int r = 0; r < count_right; r++ )
    {
      if ( selection_right[r] <= yval ) smaller_right++;
    }
    if ( (it->first) < threshold )
    {
      double emp = (double)(smaller_left)/(double)values.size();
      empDist_left.push_back( emp );
    } else {
      double emp = (double)(smaller_right)/(double)values.size();
      empDist_right.push_back( emp );
    }
  }
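  // The plug-in bandwidth below follows the reference cited in the code
  // [Taylor & Jones, 1996]:
  //   sigma_hat^2 = sigma_L^2 + sigma_R^2
  //   z_hat       = (mean_L - mean_R) / sigma_hat
  //   h = sigma_hat / ( 2*sqrt(pi) * p*(1-p) * (z_hat^2 - 1)^2 * K(z_hat) )
  // with K(z_hat) = gaussianVal(z_hat, 1.0). Note that h is not guarded
  // against sigma_hat == 0 or z_hat == +/-1.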
  // fraction of examples going to the left branch; needed by the caller
  // in both branches below
  p = (double)count_left / (double)values.size();

  // bandwidth parameter [Taylor & Jones, 1996]
  if (auto_bandwith)
  {
    double sigma_hat = sqrt( vleft.StdDev()*vleft.StdDev() + vright.StdDev()*vright.StdDev() );
    double z_hat = (double)( vleft.Mean() - vright.Mean() ) / sigma_hat;
    double tmp = (z_hat*z_hat - 1);
    h = sigma_hat / (double)( 2 * sqrt(M_PI) * p * (1-p) * tmp*tmp * gaussianVal(z_hat, 1.0) );
  }
  else
    h = 1.0; // fixed bandwidth fallback

  return true;
}
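/** Gaussian kernel used for comparing empirical distribution values:
 *  K(x) = 1 / ( sqrt(2*pi) * sqrt(2) * bandwidth ) * exp( -x*x / 4 ),
 *  i.e. a normal density with variance 2 whose normalization (but not its
 *  exponent) is scaled by the bandwidth.
 */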
double RTBMeanPostImprovement::gaussianVal ( const double input,
      const double bandwidth )
{
  return ( 1 / ( sqrt( 2 * M_PI ) * sqrt(2) * bandwidth ) * exp ( -0.25 * input * input ) );
}
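/** Recursively grow the regression tree. At each node the method
 *  1. computes the node prediction for the current selection,
 *  2. stops if max_depth or min_examples is reached,
 *  3. otherwise draws random_features random feature dimensions with
 *     random_split_tests random thresholds each, scoring every candidate
 *     split with the mean posterior improvement
 *       MPI = p * (1-p) * (1 - I_hat),
 *     where I_hat is the average kernel similarity between the left and
 *     right empirical distribution values, and
 *  4. accepts the best split only if its MPI exceeds minimum_improvement.
 */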
RegressionNode *RTBMeanPostImprovement::buildRecursive ( const NICE::VVector & x,
      const NICE::Vector & y,
      std::vector<int> & selection,
      int depth)
{
#ifdef DEBUGTREE
  fprintf (stderr, "Examples: %d (depth %d)\n", (int)selection.size(),
    (int)depth);
#endif

  RegressionNode *node = new RegressionNode ();
  node->nodePrediction( y, selection );
  double lsError = node->lsError; // note: not used by the MPI criterion

  if ( depth > max_depth )
  {
#ifdef DEBUGTREE
    fprintf (stderr, "RTBMeanPostImprovement: maximum depth reached!\n");
#endif
    node->trainExamplesIndices = selection;
    return node;
  }

  if ( (int)selection.size() < min_examples )
  {
#ifdef DEBUGTREE
    fprintf (stderr, "RTBMeanPostImprovement: minimum examples reached %d < %d!\n",
      (int)selection.size(), min_examples );
#endif
    node->trainExamplesIndices = selection;
    return node;
  }
  int best_feature = 0;
  double best_threshold = 0.0;
  double best_improvement = -1.0;
  vector<pair<double, int> > values;

  for ( int k = 0; k < random_features; k++ )
  {
#ifdef DETAILTREE
    fprintf (stderr, "calculating random feature %d\n", k );
#endif
    // pick a random feature dimension and collect its values
    int f = rand() % x[0].size();
    values.clear();
    collectFeatureValues ( x, selection, f, values );

    double minValue = (min_element ( values.begin(), values.end() ))->first;
    double maxValue = (max_element ( values.begin(), values.end() ))->first;

#ifdef DETAILTREE
    fprintf (stderr, "max %f min %f\n", maxValue, minValue );
    ofstream datafile;
    char buffer [32];
    snprintf(buffer, sizeof(buffer), "detailtree%d.dat", k);
    datafile.open( buffer );
    datafile << "# This file is called detailtree<k>.dat" << endl;
    datafile << "# Data of the mean posterior improvement criterion" << endl;
    datafile << "# threshold \tI \t\tMPI" << endl;
#endif
    if ( maxValue - minValue < 1e-7 ) continue;

    for ( int i = 0; i < random_split_tests; i++ )
    {
      // draw a uniformly distributed random threshold in [minValue, maxValue]
      double threshold = rand() * (maxValue - minValue) / RAND_MAX + minValue;
      //double step = (maxValue - minValue) / random_split_tests;
      //threshold = minValue + i*step;
#ifdef DETAILTREE
      fprintf (stderr, "calculating split f/s (t) %d/%d (%f)\n", k, i, threshold );
#endif
      vector<double> empDist_left, empDist_right;
      int count_left, count_right;
      double h, p;
      if ( ! improvementLeftRight( values, y, threshold, empDist_left,
             empDist_right, count_left, count_right, h, p) )
        continue;

      // mean posterior improvement: I_hat is the average kernel similarity
      // between the empirical distribution values of the two sides
      double I_hat = 0.0;
      for ( int l = 0; l < count_left; l++ )
      {
        for ( int r = 0; r < count_right; r++ )
        {
          I_hat += gaussianVal( (empDist_left[l] - empDist_right[r]), h );
          //I_hat += (empDist_left[l] - empDist_right[r]);
        }
      }
      I_hat /= ((double)count_left*(double)count_right);
      double mpi_hat = p * (1-p) * (1-I_hat);
#ifdef DETAILTREE
      fprintf (stderr, "pL=%f, pR=%f, I=%f --> M=%f\n", p, (1-p), I_hat, mpi_hat);
      datafile << threshold << " " << I_hat << " " << mpi_hat << endl;
#endif
      if ( mpi_hat > best_improvement )
      {
        best_improvement = mpi_hat;
        best_threshold = threshold;
        best_feature = f;
      }
    }
#ifdef DETAILTREE
    datafile.close();
#endif
  }
#ifdef DETAILTREE
  fprintf (stderr, "t %f for feature %i\n", best_threshold, best_feature );
#endif

  if ( best_improvement < minimum_improvement )
  {
#ifdef DEBUGTREE
    fprintf (stderr, "RTBMeanPostImprovement: error reduction too small!\n");
#endif
    node->trainExamplesIndices = selection;
    return node;
  }

  node->f = best_feature;
  node->threshold = best_threshold;

  // recompute examples_left and examples_right for the chosen split
  vector<int> best_examples_left;
  vector<int> best_examples_right;
  values.clear();
  collectFeatureValues( x, selection, best_feature, values);

  best_examples_left.reserve ( values.size() / 2 );
  best_examples_right.reserve ( values.size() / 2 );
  for ( vector< pair < double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    double value = it->first;
    if ( value < best_threshold )
      best_examples_left.push_back( it->second );
    else
      best_examples_right.push_back( it->second );
  }

  node->left = buildRecursive( x, y, best_examples_left, depth+1 );
  node->right = buildRecursive( x, y, best_examples_right, depth+1 );

  return node;
}
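/** Entry point: build the tree on all examples (indices 0 .. y.size()-1). */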
RegressionNode *RTBMeanPostImprovement::build( const NICE::VVector & x,
      const NICE::Vector & y )
{
  vector<int> all;
  all.reserve ( y.size() );
  for ( uint i = 0; i < y.size(); i++ )
    all.push_back( i );

  return buildRecursive( x, y, all, 0);
}
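
/* Usage sketch (a minimal example; everything outside this file -- the
   config file name and the section name -- is illustrative, assuming the
   surrounding NICE/OBJREC regression framework):

     Config conf ( "regression.conf" );
     RTBMeanPostImprovement builder ( &conf, "RTBMeanPostImprovement" );

     NICE::VVector x;  // one feature vector per training example
     NICE::Vector y;   // one regression target per training example
     // ... fill x and y ...

     RegressionNode *root = builder.build ( x, y );
*/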