RTBLinear.cpp

/**
 * @file RTBLinear.cpp
 * @brief randomized regression tree builder that fits a linear least-squares (LSE) model in every inner node during training
 * @author Frank Prüfer
 * @date 09/17/2013
 */
#include <iostream>
#include <cstdio>    // fprintf
#include <cstdlib>   // rand, srand, RAND_MAX
#include <ctime>     // time
#include <algorithm> // min_element, max_element

#include "RTBLinear.h"
#include "vislearning/regression/linregression/LinRegression.h"

#undef DEBUGTREE
#undef DETAILTREE

using namespace OBJREC;
using namespace std;
using namespace NICE;
RTBLinear::RTBLinear ( const Config *conf, std::string section )
{
  random_split_tests = conf->gI ( section, "random_split_tests", 10 );
  random_features = conf->gI ( section, "random_features", 500 );
  max_depth = conf->gI ( section, "max_depth", 10 );
  min_examples = conf->gI ( section, "min_examples", 50 );
  // note: this value is read from the "RandomForest" section, not from `section`
  minimum_error_reduction = conf->gD ( "RandomForest", "minimum_error_reduction", 10e-3 );
  save_indices = conf->gB ( section, "save_indices", false );

  if ( conf->gB ( section, "start_random_generator", false ) )
    srand ( time ( NULL ) );
}
RTBLinear::~RTBLinear()
{
}
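
/**
 * @brief Fits a linear least-squares model to (x, y) via LinRegression and
 *        returns the mean squared residual.
 * @param x       input vectors
 * @param y       regression targets
 * @param numEx   number of examples in x and y
 * @param lsError output: mean squared prediction error of the fitted model
 */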
void RTBLinear::computeLinearLSError ( const VVector& x,
                                       const Vector& y,
                                       const int& numEx,
                                       double& lsError )
{
  LinRegression *lreg = new LinRegression;
  lreg->teach ( x, y );

  // squared residuals of the fitted model
  NICE::Vector diff ( numEx );
  for ( int i = 0; i < numEx; i++ )
  {
    diff[i] = y[i] - lreg->predict ( x[i] );
    diff[i] *= diff[i];
  }

  // mean squared error
  lsError = diff.Mean();
  delete lreg;
}
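
/**
 * @brief Partitions the (feature value, example index) pairs at a threshold
 *        and computes the least-squares error of a 1-D linear model fitted
 *        to each side.
 * @return false if either side would be empty or smaller than min_examples
 *         (i.e. no valid split), true otherwise
 */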
bool RTBLinear::errorReductionLeftRight ( const vector< pair< double, int > > values,
                                          const Vector & y,
                                          double threshold,
                                          double & error_left,
                                          double & error_right,
                                          int & count_left,
                                          int & count_right )
{
  count_left = 0;
  count_right = 0;
  vector<int> selection_left;
  vector<int> selection_right;

  NICE::VVector xLeft;
  NICE::VVector xRight;

  // partition the (feature value, example index) pairs at the threshold
  for ( vector< pair< double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    double value = it->first;
    if ( value < threshold )
    {
      count_left++;
      selection_left.push_back ( it->second );
      NICE::Vector tmp ( 1, value );
      xLeft.push_back ( tmp );
    }
    else
    {
      count_right++;
      selection_right.push_back ( it->second );
      NICE::Vector tmp2 ( 1, value );
      xRight.push_back ( tmp2 );
    }
  }

  if ( ( count_left == 0 ) || ( count_right == 0 ) )
    return false; // no split
  if ( ( count_left < min_examples ) || ( count_right < min_examples ) )
    return false; // no split

  // fit a linear model to each side and measure its least-squares error
  NICE::Vector yLeft ( count_left );
  for ( int i = 0; i < count_left; i++ )
    yLeft[i] = y[ selection_left[i] ];
  computeLinearLSError ( xLeft, yLeft, count_left, error_left );

  NICE::Vector yRight ( count_right );
  for ( int i = 0; i < count_right; i++ )
    yRight[i] = y[ selection_right[i] ];
  computeLinearLSError ( xRight, yRight, count_right, error_right );

  return true;
}
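
/**
 * @brief Recursively grows the tree: samples random_features feature
 *        dimensions and random_split_tests thresholds per dimension, keeps
 *        the split with the largest error reduction, and turns the node into
 *        a leaf when max_depth, min_examples, or minimum_error_reduction
 *        stops further splitting.
 */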
RegressionNode *RTBLinear::buildRecursive ( const NICE::VVector & x,
                                            const NICE::Vector & y,
                                            std::vector<int> & selection,
                                            int depth )
{
#ifdef DEBUGTREE
  fprintf ( stderr, "Examples: %d (depth %d)\n", (int)selection.size(), (int)depth );
#endif
  RegressionNode *node = new RegressionNode ();
  // node->nodePrediction( y, selection );

  // least-squares error of a linear model fitted to this node's examples
  double lsError;
  NICE::VVector xSel;
  NICE::Vector ySel ( selection.size() );
  for ( uint i = 0; i < selection.size(); i++ )
  {
    xSel.push_back ( x[ selection[i] ] );
    ySel[i] = y[ selection[i] ];
  }
  computeLinearLSError ( xSel, ySel, (int)selection.size(), lsError );
  if ( depth > max_depth )
  {
#ifdef DEBUGTREE
    fprintf ( stderr, "RTBLinear: maximum depth reached!\n" );
#endif
    node->trainExamplesIndices = selection;
    return node;
  }

  if ( (int)selection.size() < min_examples )
  {
#ifdef DEBUGTREE
    fprintf ( stderr, "RTBLinear: minimum number of examples reached %d < %d!\n",
              (int)selection.size(), min_examples );
#endif
    node->trainExamplesIndices = selection;
    return node;
  }
  int best_feature = 0;
  double best_threshold = 0.0;
  double best_reduct = -1.0;
  vector< pair< double, int > > values;
  double lsError_left = 0.0;
  double lsError_right = 0.0;

  for ( int k = 0; k < random_features; k++ )
  {
#ifdef DETAILTREE
    fprintf ( stderr, "calculating random feature %d\n", k );
#endif
    // pick a random feature dimension and collect its values over the current selection
    int f = rand() % x[0].size();
    values.clear();
    collectFeatureValues ( x, selection, f, values );

    double minValue = ( min_element ( values.begin(), values.end() ) )->first;
    double maxValue = ( max_element ( values.begin(), values.end() ) )->first;
#ifdef DETAILTREE
    fprintf ( stderr, "max %f min %f\n", maxValue, minValue );
#endif
    // skip (nearly) constant features
    if ( maxValue - minValue < 1e-7 )
      continue;

    for ( int i = 0; i < random_split_tests; i++ )
    {
      // draw a random threshold uniformly from [minValue, maxValue]
      double threshold = rand() * ( maxValue - minValue ) / RAND_MAX + minValue;
#ifdef DETAILTREE
      fprintf ( stderr, "calculating split f/s(f) %d/%d %f\n", k, i, threshold );
#endif
      lsError_left = 0.0;
      lsError_right = 0.0;
      int count_left, count_right;
      if ( ! errorReductionLeftRight ( values, y, threshold, lsError_left,
                                       lsError_right, count_left, count_right ) )
        continue;

      // alternative: weight the children errors by their example fractions
      // double pl = count_left / (double)( count_left + count_right );
      // double errorReduction = lsError - pl * lsError_left - ( 1.0 - pl ) * lsError_right;
      double errorReduction = lsError - lsError_left - lsError_right;
      if ( errorReduction > best_reduct )
      {
        best_reduct = errorReduction;
        best_threshold = threshold;
        best_feature = f;
#ifdef DETAILTREE
        fprintf ( stderr, "t %f for feature %i\n", best_threshold, best_feature );
#endif
      }
    }
  }
  if ( best_reduct < minimum_error_reduction )
  {
#ifdef DEBUGTREE
    fprintf ( stderr, "RTBLinear: error reduction too small!\n" );
#endif
    node->trainExamplesIndices = selection;
    return node;
  }

  node->f = best_feature;
  node->threshold = best_threshold;

  // recompute examples_left and examples_right for the best split
  vector<int> best_examples_left;
  vector<int> best_examples_right;
  values.clear();
  collectFeatureValues ( x, selection, best_feature, values );

  best_examples_left.reserve ( values.size() / 2 );
  best_examples_right.reserve ( values.size() / 2 );
  for ( vector< pair< double, int > >::const_iterator it = values.begin();
        it != values.end(); it++ )
  {
    double value = it->first;
    if ( value < best_threshold )
      best_examples_left.push_back ( it->second );
    else
      best_examples_right.push_back ( it->second );
  }

  node->left = buildRecursive ( x, y, best_examples_left, depth + 1 );
  node->right = buildRecursive ( x, y, best_examples_right, depth + 1 );

  return node;
}
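
/**
 * @brief Builds the regression tree from all training examples.
 */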
RegressionNode *RTBLinear::build ( const NICE::VVector & x,
                                   const NICE::Vector & y )
{
  // start at the root with the indices of all training examples
  vector<int> all;
  all.reserve ( y.size() );
  for ( uint i = 0; i < y.size(); i++ )
    all.push_back ( i );

  return buildRecursive ( x, y, all, 0 );
}
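
/*
 * Minimal usage sketch (illustrative only, not part of the original file).
 * It assumes the usual NICE::Config conventions; the config file name
 * "rtb.conf" and the section name "RTBLinear" are placeholders.
 *
 *   NICE::Config conf ( "rtb.conf" );
 *   OBJREC::RTBLinear builder ( &conf, "RTBLinear" );
 *
 *   NICE::VVector x; // one NICE::Vector of features per training example
 *   NICE::Vector y;  // regression targets, y.size() == x.size()
 *   // ... fill x and y ...
 *
 *   OBJREC::RegressionNode *root = builder.build ( x, y );
 */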