DTBClusterRandom.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * @file DTBClusterRandom.cpp
  3. * @brief build a decision tree for clustering
  4. * @author Erik Rodner
  5. * @date 05/01/2010
  6. */
  7. #include <iostream>
  8. #include "vislearning/classifier/fpclassifier/randomforest/DTBClusterRandom.h"
  9. using namespace OBJREC;
  10. using namespace std;
  11. using namespace NICE;
  12. DTBClusterRandom::DTBClusterRandom( const Config *conf, std::string section )
  13. {
  14. max_depth = conf->gI(section, "max_depth", 10 );
  15. min_examples = conf->gI(section, "min_examples", 50);
  16. }
  17. DTBClusterRandom::~DTBClusterRandom()
  18. {
  19. }
  20. DecisionNode *DTBClusterRandom::buildRecursive ( const FeaturePool & fp,
  21. const Examples & examples,
  22. vector<int> & examples_selection,
  23. FullVector & distribution,
  24. int maxClassNo,
  25. int depth )
  26. {
  27. #ifdef DEBUGTREE
  28. fprintf (stderr, "Examples: %d (depth %d)\n", (int)examples_selection.size(),
  29. (int)depth);
  30. #endif
  31. DecisionNode *node = new DecisionNode ();
  32. node->distribution = distribution;
  33. if ( depth > max_depth ) {
  34. #ifdef DEBUGTREE
  35. fprintf (stderr, "DTBClusterRandom: maxmimum depth reached !\n");
  36. #endif
  37. return node;
  38. }
  39. if ( (int)examples_selection.size() < min_examples ) {
  40. #ifdef DEBUGTREE
  41. fprintf (stderr, "DTBClusterRandom: minimum examples reached %d < %d !\n",
  42. (int)examples_selection.size(), min_examples );
  43. #endif
  44. return node;
  45. }
  46. Feature *f = fp.getRandomFeature ();
  47. FeatureValuesUnsorted values;
  48. values.clear();
  49. f->calcFeatureValues ( examples, examples_selection, values );
  50. sort ( values.begin(), values.end() );
  51. double minValue = (values.begin())->first;
  52. double maxValue = (values.begin()+(values.size()-1))->first;
  53. double median = (values.begin()+values.size()/2)->first;
  54. #ifdef DETAILTREE
  55. fprintf (stderr, "max %f min %f median %f\n", maxValue, minValue, median );
  56. #endif
  57. if ( maxValue - minValue < 1e-30 ) {
  58. #ifdef DEBUGTREE
  59. cerr << "DTBClusterRandom: max - min < 1e-30: exit" << endl;
  60. #endif
  61. return node;
  62. }
  63. node->f = f->clone();
  64. node->threshold = median;
  65. // re calculating examples_left and examples_right
  66. vector<int> best_examples_left;
  67. vector<int> best_examples_right;
  68. FullVector best_distribution_left (maxClassNo+1);
  69. FullVector best_distribution_right (maxClassNo+1);
  70. best_distribution_left.set(0.0);
  71. best_distribution_right.set(0.0);
  72. best_examples_left.reserve ( values.size() / 2 );
  73. best_examples_right.reserve ( values.size() / 2 );
  74. for ( FeatureValuesUnsorted::const_iterator i = values.begin();
  75. i != values.end();
  76. i++ )
  77. {
  78. double value = i->first;
  79. int classno = i->second;
  80. if ( value < median ) {
  81. best_examples_left.push_back ( i->third );
  82. best_distribution_left[classno]++;
  83. } else {
  84. best_examples_right.push_back ( i->third );
  85. best_distribution_right[classno]++;
  86. }
  87. }
  88. #ifdef DEBUGTREE
  89. node->f->store(cerr);
  90. cerr << endl;
  91. #endif
  92. node->left = buildRecursive ( fp, examples, best_examples_left,
  93. best_distribution_left, maxClassNo, depth+1 );
  94. node->right = buildRecursive ( fp, examples, best_examples_right,
  95. best_distribution_right, maxClassNo, depth+1 );
  96. return node;
  97. }
  98. DecisionNode *DTBClusterRandom::build ( const FeaturePool & fp,
  99. const Examples & examples,
  100. int maxClassNo )
  101. {
  102. int index = 0;
  103. FullVector distribution ( maxClassNo+1 );
  104. vector<int> all;
  105. all.reserve ( examples.size() );
  106. for ( Examples::const_iterator j = examples.begin();
  107. j != examples.end();
  108. j++ )
  109. {
  110. int classno = j->first;
  111. distribution[classno] += j->second.weight;
  112. all.push_back ( index );
  113. index++;
  114. }
  115. return buildRecursive ( fp, examples, all, distribution, maxClassNo, 0 );
  116. }