RTBClusterRandom.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. /**
  2. * @file RTBClusterRandom.cpp
  3. * @brief random regression tree
  4. * @author Sven Sickert
  5. * @date 07/19/2013
  6. */
#include <algorithm>
#include <iostream>
#include "RTBClusterRandom.h"
  9. using namespace OBJREC;
  10. #undef DEBUGTREE
  11. #undef DETAILTREE
  12. using namespace std;
  13. using namespace NICE;
  14. RTBClusterRandom::RTBClusterRandom( const Config *conf, std::string section )
  15. {
  16. max_depth = conf->gI(section, "max_depth", 20 );
  17. min_examples = conf->gI(section, "min_examples", 10);
  18. save_indices = conf->gB(section, "save_indices", false);
  19. if ( conf->gB(section, "start_random_generator", false ) )
  20. srand(time(NULL));
  21. }
  22. RTBClusterRandom::~RTBClusterRandom()
  23. {
  24. }
  25. bool RTBClusterRandom::balancingLeftRight(const vector< pair< double, int > > values,
  26. double threshold,
  27. int& count_left,
  28. int& count_right)
  29. {
  30. count_left = 0;
  31. count_right = 0;
  32. for ( vector< pair< double, int > >::const_iterator it = values.begin();
  33. it != values.end(); it++ )
  34. {
  35. double value = it->first;
  36. if ( value < threshold )
  37. {
  38. count_left++;
  39. }
  40. else
  41. {
  42. count_right++;
  43. }
  44. }
  45. #ifdef DETAILTREE
  46. fprintf (stderr, "left vs. right: %d : %d\n", count_left, count_right );
  47. #endif
  48. if ( (count_left == 0) || (count_right == 0) )
  49. return false; // no split
  50. return true;
  51. }
  52. RegressionNode *RTBClusterRandom::buildRecursive ( const NICE::VVector & x,
  53. const NICE::Vector & y,
  54. std::vector<int> & selection,
  55. int depth)
  56. {
  57. #ifdef DEBUGTREE
  58. fprintf (stderr, "Examples: %d (depth %d)\n", (int)selection.size(),
  59. (int)depth);
  60. #endif
  61. RegressionNode *node = new RegressionNode ();
  62. node->nodePrediction( y, selection );
  63. double lsError = node->lsError;
  64. if ( depth > max_depth )
  65. {
  66. #ifdef DEBUGTREE
  67. fprintf (stderr, "RTBClusterRandom: maxmimum depth reached !\n");
  68. #endif
  69. node->trainExamplesIndices = selection;
  70. return node;
  71. }
  72. if ( (int)selection.size() < min_examples )
  73. {
  74. #ifdef DEBUGTREE
  75. fprintf (stderr, "RTBClusterRandom: minimum examples reached %d < %d !\n",
  76. (int)selection.size(), min_examples );
  77. #endif
  78. node->trainExamplesIndices = selection;
  79. return node;
  80. }
  81. vector<pair<double, int> > values;
  82. int f = rand() % x[0].size();
  83. values.clear();
  84. collectFeatureValues ( x, selection, f, values );
  85. double median = (values.begin() + values.size() / 2)->first;
  86. #ifdef DETAILTREE
  87. double minValue = (min_element ( values.begin(), values.end() ))->first;
  88. double maxValue = (max_element ( values.begin(), values.end() ))->first;
  89. fprintf (stderr, "max %f min %f med %f\n", maxValue, minValue, median );
  90. #endif
  91. int count_left, count_right;
  92. if ( ! balancingLeftRight( values, median, count_left, count_right) )
  93. {
  94. fprintf ( stderr, "RTBClusterRandom: no split possible (empty leaf)\n" );
  95. node->trainExamplesIndices = selection;
  96. return node;
  97. }
  98. #ifdef DETAILTREE
  99. fprintf (stderr, "t %f for feature %d\n", median, f );
  100. #endif
  101. node->f = f;
  102. node->threshold = median;
  103. // re calculating examples_left and examples_right
  104. vector<int> best_examples_left;
  105. vector<int> best_examples_right;
  106. best_examples_left.reserve ( values.size() / 2 );
  107. best_examples_right.reserve ( values.size() / 2 );
  108. for ( vector< pair < double, int > >::const_iterator it = values.begin();
  109. it != values.end(); it++ )
  110. {
  111. double value = it->first;
  112. if ( value < median )
  113. best_examples_left.push_back( it->second );
  114. else
  115. best_examples_right.push_back( it->second );
  116. }
  117. node->left = buildRecursive( x, y, best_examples_left, depth+1 );
  118. node->right = buildRecursive( x, y, best_examples_right, depth+1 );
  119. return node;
  120. }
  121. RegressionNode *RTBClusterRandom::build( const NICE::VVector & x,
  122. const NICE::Vector & y )
  123. {
  124. int index = 0;
  125. vector<int> all;
  126. all.reserve ( y.size() );
  127. for ( uint i = 0; i < y.size(); i++ )
  128. {
  129. all.push_back( index );
  130. index++;
  131. }
  132. return buildRecursive( x, y, all, 0);
  133. }