RTBGrid.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * @file RTBGrid.cpp
  3. * @brief random regression tree
  4. * @author Sven Sickert
  5. * @date 07/15/2013
  6. */
  7. #include <iostream>
  8. #include "RTBGrid.h"
  9. using namespace OBJREC;
  10. #undef DEBUGTREE
  11. #undef DETAILTREE
  12. using namespace std;
  13. using namespace NICE;
  14. RTBGrid::RTBGrid( const Config *conf, std::string section )
  15. {
  16. max_depth = conf->gI(section, "max_depth", 20 );
  17. min_examples = conf->gI(section, "min_examples", 10);
  18. save_indices = conf->gB(section, "save_indices", false);
  19. if ( conf->gB(section, "start_random_generator", false ) )
  20. srand(time(NULL));
  21. }
  22. RTBGrid::~RTBGrid()
  23. {
  24. }
  25. bool RTBGrid::balancingLeftRight(const vector< pair< double, int > > values,
  26. double threshold,
  27. int& count_left,
  28. int& count_right)
  29. {
  30. count_left = 0;
  31. count_right = 0;
  32. for ( vector< pair< double, int > >::const_iterator it = values.begin();
  33. it != values.end(); it++ )
  34. {
  35. double value = it->first;
  36. if ( value < threshold )
  37. {
  38. count_left++;
  39. }
  40. else
  41. {
  42. count_right++;
  43. }
  44. }
  45. #ifdef DETAILTREE
  46. fprintf (stderr, "left vs. right: %d : %d\n", count_left, count_right );
  47. #endif
  48. if ( (count_left == 0) || (count_right == 0) )
  49. return false; // no split
  50. return true;
  51. }
  52. RegressionNode *RTBGrid::buildRecursive ( const NICE::VVector & x,
  53. const std::vector<std::vector<double> > & limits,
  54. std::vector<int> & selection,
  55. int depth)
  56. {
  57. #ifdef DEBUGTREE
  58. fprintf (stderr, "Examples: %d (depth %d)\n", (int)selection.size(),
  59. (int)depth);
  60. #endif
  61. RegressionNode *node = new RegressionNode ();
  62. if ( depth > max_depth )
  63. {
  64. #ifdef DEBUGTREE
  65. fprintf (stderr, "RTBGrid: maxmimum depth reached !\n");
  66. #endif
  67. node->trainExamplesIndices = selection;
  68. return node;
  69. }
  70. if ( (int)selection.size() < min_examples )
  71. {
  72. #ifdef DEBUGTREE
  73. fprintf (stderr, "RTBGrid: minimum examples reached %d < %d !\n",
  74. (int)selection.size(), min_examples );
  75. #endif
  76. node->trainExamplesIndices = selection;
  77. return node;
  78. }
  79. vector<pair<double, int> > values;
  80. int f = depth % x[0].size();
  81. values.clear();
  82. collectFeatureValues ( x, selection, f, values );
  83. #ifdef DETAILTREE
  84. double minValue = (min_element ( values.begin(), values.end() ))->first;
  85. double maxValue = (max_element ( values.begin(), values.end() ))->first;
  86. fprintf (stderr, "max %f min %f\n", maxValue, minValue );
  87. #endif
  88. double threshold = 0.5 * (limits[f][0]+limits[f][1]);
  89. int tmp = depth;
  90. while( tmp > (int)x[0].size() )
  91. {
  92. threshold *= 0.5;
  93. tmp -= x[0].size();
  94. }
  95. int count_left, count_right;
  96. if ( ! balancingLeftRight( values, threshold, count_left, count_right) )
  97. {
  98. fprintf ( stderr, "RTBGrid: no split possible (empty leaf)\n" );
  99. node->trainExamplesIndices = selection;
  100. return node;
  101. }
  102. #ifdef DETAILTREE
  103. fprintf (stderr, "t %f for feature %d\n", threshold, f );
  104. #endif
  105. node->f = f;
  106. node->threshold = threshold;
  107. // re calculating examples_left and examples_right
  108. vector<int> best_examples_left;
  109. vector<int> best_examples_right;
  110. best_examples_left.reserve ( values.size() / 2 );
  111. best_examples_right.reserve ( values.size() / 2 );
  112. for ( vector< pair < double, int > >::const_iterator it = values.begin();
  113. it != values.end(); it++ )
  114. {
  115. double value = it->first;
  116. if ( value < threshold )
  117. best_examples_left.push_back( it->second );
  118. else
  119. best_examples_right.push_back( it->second );
  120. }
  121. node->left = buildRecursive( x, limits, best_examples_left, depth+1 );
  122. node->right = buildRecursive( x, limits, best_examples_right, depth+1 );
  123. return node;
  124. }
  125. RegressionNode *RTBGrid::build( const NICE::VVector & x,
  126. const NICE::Vector & y )
  127. {
  128. int index = 0;
  129. vector<int> all;
  130. all.reserve ( y.size() );
  131. for ( uint i = 0; i < y.size(); i++ )
  132. {
  133. all.push_back( index );
  134. index++;
  135. }
  136. // get min/max values for all features
  137. int fcount = x[0].size();
  138. vector< vector<double> > limits;
  139. for ( int j = 0; j < fcount; j++ )
  140. {
  141. double min = numeric_limits<double>::max();
  142. double max = numeric_limits<double>::min();
  143. for ( int i = 0; i < x.size(); i++ )
  144. {
  145. double value = x[i][j];
  146. if (value > max ) max = value;
  147. if (value < min ) min = value;
  148. }
  149. vector<double> flimit;
  150. flimit.push_back(min);
  151. flimit.push_back(max);
  152. limits.push_back(flimit);
  153. }
  154. return buildRecursive( x, limits, all, 0);
  155. }