DTBRandom.cpp

/**
 * @file DTBRandom.cpp
 * @brief random decision tree
 * @author Erik Rodner
 * @date 05/06/2008
 */
#include <iostream>
#include <cstdlib>    // rand, srand, RAND_MAX
#include <ctime>      // time (random seed)
#include <cmath>      // log, sqrt
#include <algorithm>  // min_element, max_element
#include <cassert>    // assert (DETAILTREE builds)

#include "vislearning/classifier/fpclassifier/randomforest/DTBRandom.h"

#undef DEBUGTREE
#undef DETAILTREE

using namespace OBJREC;
using namespace std;
using namespace NICE;

DTBRandom::DTBRandom ( const Config *conf, std::string section )
{
    random_split_tests = conf->gI ( section, "random_split_tests", 10 );
    random_features = conf->gI ( section, "random_features", 500 );
    max_depth = conf->gI ( section, "max_depth", 10 );
    minimum_information_gain = conf->gD ( section, "minimum_information_gain", 10e-7 );
    minimum_entropy = conf->gD ( section, "minimum_entropy", 10e-5 );
    use_shannon_entropy = conf->gB ( section, "use_shannon_entropy", false );
    min_examples = conf->gI ( section, "min_examples", 50 );
    save_indices = conf->gB ( section, "save_indices", false );

    if ( conf->gB ( section, "start_random_generator", false ) )
        srand ( time ( NULL ) );
}
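/* random_features and random_split_tests define the randomized split search
 * used in buildRecursive(): at each node, random_features candidate features
 * are drawn and random_split_tests thresholds are sampled per feature, so
 * K = random_split_tests * random_features split tests are scored per node
 * (the "K" reported against sqrt(N) in build(), following the extremely
 * randomized trees analysis of Geurts et al.). */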
DTBRandom::~DTBRandom()
{
}
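/* entropyLeftRight() splits the (weighted) class histogram at a threshold and
 * computes the entropy of each side. It uses the identity (natural logarithm,
 * n_c = weighted count of class c, N = sum_c n_c):
 *
 *   H = - sum_c (n_c/N) * log(n_c/N)
 *     = log(N) - (1/N) * sum_c n_c * log(n_c)
 *
 * which is why the loops below accumulate -n_c*log(n_c), divide by the branch
 * count and then add log(count). */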
bool DTBRandom::entropyLeftRight ( const FeatureValuesUnsorted & values,
                                   double threshold,
                                   double* stat_left,
                                   double* stat_right,
                                   double & entropy_left,
                                   double & entropy_right,
                                   double & count_left,
                                   double & count_right,
                                   int maxClassNo )
{
    count_left = 0;
    count_right = 0;

    // accumulate the weighted class histograms of both sides of the split
    for ( FeatureValuesUnsorted::const_iterator i = values.begin(); i != values.end(); i++ )
    {
        int classno = i->second;
        double value = i->first;
        if ( value < threshold ) {
            stat_left[classno] += i->fourth;
            count_left += i->fourth;
        } else {
            stat_right[classno] += i->fourth;
            count_right += i->fourth;
        }
    }

    // reject degenerate splits where one side is empty
    if ( (count_left == 0) || (count_right == 0) )
        return false;

    entropy_left = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_left[j] != 0 )
            entropy_left -= stat_left[j] * log(stat_left[j]);
    entropy_left /= count_left;
    entropy_left += log(count_left);

    entropy_right = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_right[j] != 0 )
            entropy_right -= stat_right[j] * log(stat_right[j]);
    entropy_right /= count_right;
    entropy_right += log(count_right);

    return true;
}
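/* buildRecursive() grows the tree depth-first. A node becomes a leaf when one
 * of the stopping criteria fires (maximum depth, too few examples, entropy or
 * information gain below threshold); otherwise the best of the randomly
 * sampled splits is stored in the node and both children are built
 * recursively on the induced partition. */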
DecisionNode *DTBRandom::buildRecursive ( const FeaturePool & fp,
                                          const Examples & examples,
                                          vector<int> & examples_selection,
                                          FullVector & distribution,
                                          double e,
                                          int maxClassNo,
                                          int depth )
{
#ifdef DEBUGTREE
    fprintf (stderr, "Examples: %d (depth %d)\n",
        (int)examples_selection.size(), (int)depth);
#endif

    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    if ( depth > max_depth ) {
#ifdef DEBUGTREE
        fprintf (stderr, "DTBRandom: maximum depth reached!\n");
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    if ( (int)examples_selection.size() < min_examples ) {
#ifdef DEBUGTREE
        fprintf (stderr, "DTBRandom: minimum examples reached %d < %d!\n",
            (int)examples_selection.size(), min_examples );
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    // REALLY BAD FIXME
    if ( (e <= minimum_entropy) && (e != 0.0) ) {
        //if ( e <= minimum_entropy ) {
#ifdef DEBUGTREE
        fprintf (stderr, "DTBRandom: minimum entropy reached!\n");
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    Feature *best_feature = NULL;
    double best_threshold = 0.0;
    double best_ig = -1.0;
    FeatureValuesUnsorted values;
    double *best_distribution_left = new double [maxClassNo+1];
    double *best_distribution_right = new double [maxClassNo+1];
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double best_entropy_left = 0.0;
    double best_entropy_right = 0.0;
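    /* Randomized split search: draw random_features candidate features; for
     * each one, sample random_split_tests thresholds uniformly from the value
     * range [minValue, maxValue] of the feature on the current example set,
     * and keep the candidate with the highest information gain. */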
    for ( int k = 0 ; k < random_features ; k++ )
    {
#ifdef DETAILTREE
        fprintf (stderr, "calculating random feature %d\n", k );
#endif
        Feature *f = fp.getRandomFeature ();
        values.clear();
        f->calcFeatureValues ( examples, examples_selection, values );

        double minValue = (min_element ( values.begin(), values.end() ))->first;
        double maxValue = (max_element ( values.begin(), values.end() ))->first;
#ifdef DETAILTREE
        fprintf (stderr, "max %f min %f\n", maxValue, minValue );
#endif
        // skip (nearly) constant features: no threshold can separate them
        if ( maxValue - minValue < 1e-7 ) continue;

        for ( int i = 0 ; i < random_split_tests ; i++ )
        {
            // threshold sampled uniformly from [minValue, maxValue]
            double threshold = rand() * (maxValue - minValue) / RAND_MAX + minValue;
#ifdef DETAILTREE
            fprintf (stderr, "calculating split f/s(f) %d/%d %f\n", k, i, threshold );
#endif
            double el, er;

            // clear the per-split class histograms
            for ( int c = 0 ; c <= maxClassNo ; c++ )
            {
                distribution_left[c] = 0;
                distribution_right[c] = 0;
            }

            double count_left;
            double count_right;
            if ( ! entropyLeftRight ( values, threshold,
                                      distribution_left, distribution_right,
                                      el, er, count_left, count_right, maxClassNo ) )
                continue;
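            /* Information gain of the candidate split (p_l = left fraction):
             *   IG = H(node) - p_l * H(left) - (1 - p_l) * H(right)
             * With use_shannon_entropy the gain is normalized as
             *   2 * IG / ( H(node) + H(split) ),
             * i.e. the mutual information between class label and split is
             * divided by the average of node entropy and split entropy, a
             * symmetric-uncertainty-style criterion (cf. the "mutual
             * information / shannon entropy" debug output below). */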
            double pl = count_left / (count_left + count_right);
            double ig = e - pl*el - (1-pl)*er;

            if ( use_shannon_entropy )
            {
                double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
                ig = 2*ig / ( e + esplit );
            }
#ifdef DETAILTREE
            fprintf (stderr, "ig %f el %f er %f e %f\n", ig, el, er, e );
            assert ( ig >= -1e-7 );
#endif
            if ( ig > best_ig )
            {
                best_ig = ig;
                best_threshold = threshold;
#ifdef DETAILTREE
                fprintf (stderr, "t %f\n", best_threshold );
#endif
                best_feature = f;
                for ( int c = 0 ; c <= maxClassNo ; c++ )
                {
                    best_distribution_left[c] = distribution_left[c];
                    best_distribution_right[c] = distribution_right[c];
                }
                best_entropy_left = el;
                best_entropy_right = er;
            }
        }
    }

    delete [] distribution_left;
    delete [] distribution_right;

    if ( best_ig < minimum_information_gain )
    {
#ifdef DEBUGTREE
        fprintf (stderr, "DTBRandom: minimum information gain reached!\n");
#endif
        delete [] best_distribution_left;
        delete [] best_distribution_right;
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    node->f = best_feature->clone();
    node->threshold = best_threshold;

    // recompute examples_left and examples_right for the winning split
    vector<int> best_examples_left;
    vector<int> best_examples_right;
    values.clear();
    best_feature->calcFeatureValues ( examples, examples_selection, values );
    best_examples_left.reserve ( values.size() / 2 );
    best_examples_right.reserve ( values.size() / 2 );

    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end();
          i++ )
    {
        double value = i->first;
        if ( value < best_threshold ) {
            best_examples_left.push_back ( i->third );
        } else {
            best_examples_right.push_back ( i->third );
        }
    }
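    /* Assuming deterministic feature evaluation, the partition above uses the
     * same (value < threshold) test as entropyLeftRight(), so the two index
     * sets match the histograms in best_distribution_left/right exactly. */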
#ifdef DEBUGTREE
    node->f->store(cerr);
    cerr << endl;
    fprintf (stderr, "mutual information / shannon entropy %f entropy %f, left entropy %f right entropy %f\n",
        best_ig, e, best_entropy_left, best_entropy_right );
#endif

    // copy the winning histograms into sparse FullVector distributions
    FullVector best_distribution_left_sparse ( distribution.size() );
    FullVector best_distribution_right_sparse ( distribution.size() );
    for ( int c = 0 ; c <= maxClassNo ; c++ )
    {
        double l = best_distribution_left[c];
        double r = best_distribution_right[c];
        if ( l != 0 )
            best_distribution_left_sparse[c] = l;
        if ( r != 0 )
            best_distribution_right_sparse[c] = r;
#ifdef DEBUGTREE
        if ( (l>0) || (r>0) )
            fprintf (stderr, "DTBRandom: split of class %d (%f <-> %f)\n", c, l, r );
#endif
    }
    delete [] best_distribution_left;
    delete [] best_distribution_right;

    node->left = buildRecursive ( fp, examples, best_examples_left,
        best_distribution_left_sparse, best_entropy_left, maxClassNo, depth+1 );
    node->right = buildRecursive ( fp, examples, best_examples_right,
        best_distribution_right_sparse, best_entropy_right, maxClassNo, depth+1 );

    return node;
}
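/* Usage sketch (illustrative; the section name "DTBRandom" and the setup of
 * fp and examples are assumptions, not taken from this file):
 *
 *   DTBRandom builder ( conf, "DTBRandom" );
 *   DecisionNode *root = builder.build ( fp, examples, maxClassNo );
 *
 * build() computes the weighted class distribution and entropy at the root
 * and then delegates to buildRecursive(). */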
DecisionNode *DTBRandom::build ( const FeaturePool & fp,
                                 const Examples & examples,
                                 int maxClassNo )
{
    int index = 0;
    fprintf (stderr, "Feature Statistics (Geurts et al.): N=%d sqrt(N)=%lf K=%d\n",
        (int)fp.size(), sqrt((double)fp.size()), random_split_tests*random_features );

    FullVector distribution ( maxClassNo+1 );
    vector<int> all;
    all.reserve ( examples.size() );

    // collect all example indices and the weighted class distribution
    for ( Examples::const_iterator j = examples.begin();
          j != examples.end();
          j++ )
    {
        int classno = j->first;
        distribution[classno] += j->second.weight;
        all.push_back ( index );
        index++;
    }
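    // root entropy via the same log-sum identity used in entropyLeftRight():
    //   H = log(sum) - (1/sum) * sum_c n_c * log(n_c)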
    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
        double val = distribution[i];
        if ( val <= 0.0 ) continue;
        entropy -= val*log(val);
        sum += val;
    }
    entropy /= sum;
    entropy += log(sum);

    return buildRecursive ( fp, examples, all, distribution, entropy, maxClassNo, 0 );
}