/**
 * @file RTBRandom.cpp
 * @brief random regression tree
 * @author Sven Sickert
 * @date 06/19/2013
 */
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <algorithm>

#include "RTBRandom.h"

#undef DEBUGTREE
#undef DETAILTREE

using namespace OBJREC;
using namespace std;
using namespace NICE;
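
/** Read the tree builder's parameters from the given config section:
 *  the number of random split thresholds tested per feature, the number of
 *  random features sampled per node, the maximum depth, the minimum number
 *  of examples per node, whether to store the training indices in the
 *  leaves, and whether to seed the random generator with the current time.
 *  Note that minimum_error_reduction is read from the "RandomForest"
 *  section rather than from 'section'. */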
RTBRandom::RTBRandom ( const Config *conf, std::string section )
{
  random_split_tests = conf->gI ( section, "random_split_tests", 10 );
  random_features = conf->gI ( section, "random_features", 500 );
  max_depth = conf->gI ( section, "max_depth", 10 );
  min_examples = conf->gI ( section, "min_examples", 50 );
  minimum_error_reduction = conf->gD ( "RandomForest", "minimum_error_reduction", 10e-3 );
  save_indices = conf->gB ( section, "save_indices", false );

  if ( conf->gB ( section, "start_random_generator", false ) )
    srand ( time ( NULL ) );
}

RTBRandom::~RTBRandom()
{
}
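
/** Split the given (feature value, example index) pairs at 'threshold' and
 *  compute the least-squares error of a constant prediction on each side.
 *  Returns false (no admissible split) if either side would contain fewer
 *  than min_examples examples. */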
bool RTBRandom::errorReductionLeftRight ( const vector< pair< double, int > > values,
      const Vector & y,
      double threshold,
      double & error_left,
      double & error_right,
      int & count_left,
      int & count_right )
{
  count_left = 0;
  count_right = 0;
  vector<int> selection_left;
  vector<int> selection_right;

  for ( vector< pair< double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    double value = it->first;
    if ( value < threshold )
    {
      count_left++;
      selection_left.push_back ( it->second );
    }
    else
    {
      count_right++;
      selection_right.push_back ( it->second );
    }
  }

  //if ( (count_left == 0) || (count_right == 0) )
  //  return false; // no split

  if ( (count_left < min_examples) || (count_right < min_examples) )
    return false; // no split

  RegressionNode *left = new RegressionNode ();
  left->nodePrediction ( y, selection_left );
  error_left = left->lsError;
  delete left;

  RegressionNode *right = new RegressionNode ();
  right->nodePrediction ( y, selection_right );
  error_right = right->lsError;
  delete right;

  return true;
}
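
/** Recursively grow the regression tree. For the current selection of
 *  training examples, a random subset of features and thresholds is tested
 *  and the split with the largest reduction of the least-squares error is
 *  applied; recursion stops at max_depth, below min_examples, or when the
 *  best split reduces the error by less than minimum_error_reduction. */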
RegressionNode *RTBRandom::buildRecursive ( const NICE::VVector & x,
      const NICE::Vector & y,
      std::vector<int> & selection,
      int depth )
{
#ifdef DEBUGTREE
  fprintf (stderr, "Examples: %d (depth %d)\n", (int)selection.size(),
    (int)depth);
#endif

  RegressionNode *node = new RegressionNode ();
  node->nodePrediction ( y, selection );
  double lsError = node->lsError;

  if ( depth > max_depth )
  {
#ifdef DEBUGTREE
    fprintf (stderr, "RTBRandom: maximum depth reached!\n");
#endif
    node->trainExamplesIndices = selection;
    return node;
  }

  if ( (int)selection.size() < min_examples )
  {
#ifdef DEBUGTREE
    fprintf (stderr, "RTBRandom: minimum examples reached %d < %d!\n",
      (int)selection.size(), min_examples );
#endif
    node->trainExamplesIndices = selection;
    return node;
  }
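
  // Search for the best split: sample random_features feature dimensions
  // and, for each, test random_split_tests thresholds drawn uniformly from
  // the observed value range; keep the split with the largest error reduction.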
  int best_feature = 0;
  double best_threshold = 0.0;
  double best_reduct = -1.0;
  vector<pair<double, int> > best_values;
  vector<pair<double, int> > values;
  double lsError_left = 0.0;
  double lsError_right = 0.0;

  for ( int k = 0; k < random_features; k++ )
  {
#ifdef DETAILTREE
    fprintf (stderr, "calculating random feature %d\n", k );
#endif
    int f = rand() % x[0].size();

    values.clear();
    collectFeatureValues ( x, selection, f, values );

    double minValue = ( min_element ( values.begin(), values.end() ) )->first;
    double maxValue = ( max_element ( values.begin(), values.end() ) )->first;

#ifdef DETAILTREE
    fprintf (stderr, "max %f min %f\n", maxValue, minValue );
#endif
    if ( maxValue - minValue < 1e-7 ) continue;

    for ( int i = 0; i < random_split_tests; i++ )
    {
      double threshold = rand() * ( maxValue - minValue ) / RAND_MAX + minValue;
#ifdef DETAILTREE
      fprintf (stderr, "calculating split f/s(f) %d/%d %f\n", k, i, threshold );
#endif
      lsError_left = 0.0;
      lsError_right = 0.0;

      int count_left, count_right;
      if ( ! errorReductionLeftRight ( values, y, threshold, lsError_left,
            lsError_right, count_left, count_right ) )
        continue;

      //double pl = (count_left) / (count_left + count_right);
      //double errorReduction = lsError - pl*lsError_left - (1-pl)*lsError_right;
      double errorReduction = lsError - lsError_left - lsError_right;

      if ( errorReduction > best_reduct )
      {
        best_reduct = errorReduction;
        best_threshold = threshold;
        best_feature = f;
#ifdef DETAILTREE
        fprintf (stderr, "t %f for feature %i\n", best_threshold, best_feature );
#endif
      }
    }
  }
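
  // If even the best split reduces the error by less than
  // minimum_error_reduction, keep this node as a leaf.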
  if ( best_reduct < minimum_error_reduction )
  {
#ifdef DEBUGTREE
    fprintf (stderr, "RTBRandom: error reduction too small!\n");
#endif
    node->trainExamplesIndices = selection;
    return node;
  }

  node->f = best_feature;
  node->threshold = best_threshold;
  // recompute examples_left and examples_right for the chosen split
  vector<int> best_examples_left;
  vector<int> best_examples_right;
  values.clear();
  collectFeatureValues ( x, selection, best_feature, values );

  best_examples_left.reserve ( values.size() / 2 );
  best_examples_right.reserve ( values.size() / 2 );

  for ( vector< pair< double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    double value = it->first;
    if ( value < best_threshold )
      best_examples_left.push_back ( it->second );
    else
      best_examples_right.push_back ( it->second );
  }

  node->left = buildRecursive ( x, y, best_examples_left, depth + 1 );
  node->right = buildRecursive ( x, y, best_examples_right, depth + 1 );

  return node;
}
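
/** Build a regression tree from all training examples: the initial
 *  selection simply contains every index 0 .. y.size()-1. */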
RegressionNode *RTBRandom::build ( const NICE::VVector & x,
      const NICE::Vector & y )
{
  vector<int> all;
  all.reserve ( y.size() );
  for ( uint i = 0; i < y.size(); i++ )
    all.push_back ( i );

  return buildRecursive ( x, y, all, 0 );
}
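
/* A minimal usage sketch (not part of the original file). It assumes a
 * NICE-style Config that can be loaded from a file, with features stored
 * example-wise in a NICE::VVector; the file name "regression.conf" and the
 * config group "RTBRandom" are placeholders, not values used by this class:
 *
 *   Config conf ( "regression.conf" );
 *   RTBRandom builder ( &conf, "RTBRandom" );
 *
 *   NICE::VVector x;   // one NICE::Vector of feature values per example
 *   NICE::Vector y;    // one regression target per example
 *   // ... fill x and y with training data ...
 *
 *   RegressionNode *root = builder.build ( x, y );
 */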