DTBClusterRandom.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. /**
  2. * @file DTBClusterRandom.cpp
  3. * @brief build a decision tree for clustering
  4. * @author Erik Rodner
  5. * @date 05/01/2010
  6. */
  7. #ifdef NOVISUAL
  8. #include <vislearning/nice_nonvis.h>
  9. #else
  10. #include <vislearning/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include "vislearning/classifier/fpclassifier/randomforest/DTBClusterRandom.h"
  14. using namespace OBJREC;
  15. using namespace std;
  16. using namespace NICE;
  17. DTBClusterRandom::DTBClusterRandom( const Config *conf, std::string section )
  18. {
  19. max_depth = conf->gI(section, "max_depth", 10 );
  20. min_examples = conf->gI(section, "min_examples", 50);
  21. }
  22. DTBClusterRandom::~DTBClusterRandom()
  23. {
  24. }
  25. DecisionNode *DTBClusterRandom::buildRecursive ( const FeaturePool & fp,
  26. const Examples & examples,
  27. vector<int> & examples_selection,
  28. FullVector & distribution,
  29. int maxClassNo,
  30. int depth )
  31. {
  32. #ifdef DEBUGTREE
  33. fprintf (stderr, "Examples: %d (depth %d)\n", (int)examples_selection.size(),
  34. (int)depth);
  35. #endif
  36. DecisionNode *node = new DecisionNode ();
  37. node->distribution = distribution;
  38. if ( depth > max_depth ) {
  39. #ifdef DEBUGTREE
  40. fprintf (stderr, "DTBClusterRandom: maxmimum depth reached !\n");
  41. #endif
  42. return node;
  43. }
  44. if ( (int)examples_selection.size() < min_examples ) {
  45. #ifdef DEBUGTREE
  46. fprintf (stderr, "DTBClusterRandom: minimum examples reached %d < %d !\n",
  47. (int)examples_selection.size(), min_examples );
  48. #endif
  49. return node;
  50. }
  51. Feature *f = fp.getRandomFeature ();
  52. FeatureValuesUnsorted values;
  53. values.clear();
  54. f->calcFeatureValues ( examples, examples_selection, values );
  55. sort ( values.begin(), values.end() );
  56. double minValue = (values.begin())->first;
  57. double maxValue = (values.begin()+(values.size()-1))->first;
  58. double median = (values.begin()+values.size()/2)->first;
  59. #ifdef DETAILTREE
  60. fprintf (stderr, "max %f min %f median %f\n", maxValue, minValue, median );
  61. #endif
  62. if ( maxValue - minValue < 1e-30 ) {
  63. #ifdef DEBUGTREE
  64. cerr << "DTBClusterRandom: max - min < 1e-30: exit" << endl;
  65. #endif
  66. return node;
  67. }
  68. node->f = f->clone();
  69. node->threshold = median;
  70. // re calculating examples_left and examples_right
  71. vector<int> best_examples_left;
  72. vector<int> best_examples_right;
  73. FullVector best_distribution_left (maxClassNo+1);
  74. FullVector best_distribution_right (maxClassNo+1);
  75. best_distribution_left.set(0.0);
  76. best_distribution_right.set(0.0);
  77. best_examples_left.reserve ( values.size() / 2 );
  78. best_examples_right.reserve ( values.size() / 2 );
  79. for ( FeatureValuesUnsorted::const_iterator i = values.begin();
  80. i != values.end();
  81. i++ )
  82. {
  83. double value = i->first;
  84. int classno = i->second;
  85. if ( value < median ) {
  86. best_examples_left.push_back ( i->third );
  87. best_distribution_left[classno]++;
  88. } else {
  89. best_examples_right.push_back ( i->third );
  90. best_distribution_right[classno]++;
  91. }
  92. }
  93. #ifdef DEBUGTREE
  94. node->f->store(cerr);
  95. cerr << endl;
  96. #endif
  97. node->left = buildRecursive ( fp, examples, best_examples_left,
  98. best_distribution_left, maxClassNo, depth+1 );
  99. node->right = buildRecursive ( fp, examples, best_examples_right,
  100. best_distribution_right, maxClassNo, depth+1 );
  101. return node;
  102. }
  103. DecisionNode *DTBClusterRandom::build ( const FeaturePool & fp,
  104. const Examples & examples,
  105. int maxClassNo )
  106. {
  107. int index = 0;
  108. FullVector distribution ( maxClassNo+1 );
  109. vector<int> all;
  110. all.reserve ( examples.size() );
  111. for ( Examples::const_iterator j = examples.begin();
  112. j != examples.end();
  113. j++ )
  114. {
  115. int classno = j->first;
  116. distribution[classno] += j->second.weight;
  117. all.push_back ( index );
  118. index++;
  119. }
  120. return buildRecursive ( fp, examples, all, distribution, maxClassNo, 0 );
  121. }