DTBRandomOblique.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. /**
  2. * @file DTBRandomOblique.cpp
  3. * @brief random oblique decision tree
  4. * @author Sven Sickert
  5. * @date 10/15/2014
  6. */
  7. #include <iostream>
  8. #include <time.h>
  9. #include "DTBRandomOblique.h"
  10. #include "vislearning/features/fpfeatures/ConvolutionFeature.h"
  11. #include "core/vector/Algorithms.h"
  12. using namespace OBJREC;
  13. #define DEBUGTREE
  14. #undef DETAILTREE
  15. using namespace std;
  16. using namespace NICE;
  17. DTBRandomOblique::DTBRandomOblique ( const Config *conf, string section )
  18. {
  19. random_split_tests = conf->gI(section, "random_split_tests", 10 );
  20. max_depth = conf->gI(section, "max_depth", 10 );
  21. minimum_information_gain = conf->gD(section, "minimum_information_gain", 10e-7 );
  22. minimum_entropy = conf->gD(section, "minimum_entropy", 10e-5 );
  23. use_shannon_entropy = conf->gB(section, "use_shannon_entropy", false );
  24. min_examples = conf->gI(section, "min_examples", 50);
  25. save_indices = conf->gB(section, "save_indices", false);
  26. lambda = conf->gD(section, "lambda", 0.5 );
  27. if ( conf->gB(section, "start_random_generator", false ) )
  28. srand(time(NULL));
  29. }
  30. DTBRandomOblique::~DTBRandomOblique()
  31. {
  32. }
  33. bool DTBRandomOblique::entropyLeftRight (
  34. const FeatureValuesUnsorted & values,
  35. double threshold,
  36. double* stat_left,
  37. double* stat_right,
  38. double & entropy_left,
  39. double & entropy_right,
  40. double & count_left,
  41. double & count_right,
  42. int maxClassNo )
  43. {
  44. count_left = 0;
  45. count_right = 0;
  46. for ( FeatureValuesUnsorted::const_iterator i = values.begin(); i != values.end(); i++ )
  47. {
  48. int classno = i->second;
  49. double value = i->first;
  50. if ( value < threshold ) {
  51. stat_left[classno] += i->fourth;
  52. count_left+=i->fourth;
  53. }
  54. else
  55. {
  56. stat_right[classno] += i->fourth;
  57. count_right+=i->fourth;
  58. }
  59. }
  60. if ( (count_left == 0) || (count_right == 0) )
  61. return false;
  62. entropy_left = 0.0;
  63. for ( int j = 0 ; j <= maxClassNo ; j++ )
  64. if ( stat_left[j] != 0 )
  65. entropy_left -= stat_left[j] * log(stat_left[j]);
  66. entropy_left /= count_left;
  67. entropy_left += log(count_left);
  68. entropy_right = 0.0;
  69. for ( int j = 0 ; j <= maxClassNo ; j++ )
  70. if ( stat_right[j] != 0 )
  71. entropy_right -= stat_right[j] * log(stat_right[j]);
  72. entropy_right /= count_right;
  73. entropy_right += log (count_right);
  74. return true;
  75. }
  76. /** refresh data matrix X and label vector y */
  77. void DTBRandomOblique::getDataAndLabel(
  78. const FeaturePool &fp,
  79. const Examples &examples,
  80. const std::vector<int> &examples_selection,
  81. NICE::Matrix & matX,
  82. NICE::Vector & vecY )
  83. {
  84. ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
  85. int amountParams = f->getParameterLength();
  86. int amountExamples = examples_selection.size();
  87. NICE::Matrix X(amountExamples, amountParams, 0.0 );
  88. NICE::Vector y(amountExamples, 0.0);
  89. int matIndex = 0;
  90. for ( vector<int>::const_iterator si = examples_selection.begin();
  91. si != examples_selection.end();
  92. si++ )
  93. {
  94. const pair<int, Example> & p = examples[*si];
  95. int classno = p.first;
  96. const Example & ce = p.second;
  97. NICE::Vector pixelRepr = f->getFeatureVector( &ce );
  98. pixelRepr /= pixelRepr.Max();
  99. // TODO for multiclass scenarios we need ONEvsALL!
  100. // {0,1} -> {-1,+1}
  101. double label = 2*classno-1;
  102. label *= ce.weight;
  103. pixelRepr *= ce.weight;
  104. y.set( matIndex, label );
  105. X.setRow(matIndex,pixelRepr);
  106. matIndex++;
  107. }
  108. matX = X;
  109. vecY = y;
  110. }
  111. /** recursive building method */
  112. DecisionNode *DTBRandomOblique::buildRecursive(
  113. const FeaturePool & fp,
  114. const Examples & examples,
  115. std::vector<int> & examples_selection,
  116. FullVector & distribution,
  117. double e,
  118. int maxClassNo,
  119. int depth)
  120. {
  121. #ifdef DEBUGTREE
  122. std::cerr << "Examples: " << (int)examples_selection.size()
  123. << " (depth " << (int)depth << ")" << std::endl;
  124. #endif
  125. // initialize new node
  126. DecisionNode *node = new DecisionNode ();
  127. node->distribution = distribution;
  128. // stop criteria: max_depth, min_examples, min_entropy
  129. if ( depth > max_depth
  130. || (int)examples_selection.size() < min_examples
  131. || ( (e <= minimum_entropy) && (e != 0.0) ) ) // FIXME
  132. {
  133. #ifdef DEBUGTREE
  134. std::cerr << "DTBRandomOblique: Stopping criteria applied!" << std::endl;
  135. #endif
  136. node->trainExamplesIndices = examples_selection;
  137. return node;
  138. }
  139. // refresh/set X and y
  140. NICE::Matrix X;
  141. NICE::Vector y;
  142. getDataAndLabel( fp, examples, examples_selection, X, y );
  143. NICE::Matrix XTX = X.transpose()*X;
  144. XTX.addDiagonal ( NICE::Vector( XTX.rows(), lambda) );
  145. NICE::Matrix G;
  146. NICE::Vector beta;
  147. choleskyDecomp(XTX, G);
  148. choleskyInvert(G, XTX);
  149. NICE::Matrix temp = XTX * X.transpose();
  150. beta.multiply(temp,y,false);
  151. // choleskySolve(G, y, beta );
  152. // variables
  153. double best_threshold = 0.0;
  154. double best_ig = -1.0;
  155. FeatureValuesUnsorted values;
  156. double *best_distribution_left = new double [maxClassNo+1];
  157. double *best_distribution_right = new double [maxClassNo+1];
  158. double *distribution_left = new double [maxClassNo+1];
  159. double *distribution_right = new double [maxClassNo+1];
  160. double best_entropy_left = 0.0;
  161. double best_entropy_right = 0.0;
  162. // Setting Convolutional Feature
  163. ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
  164. f->setParameterVector( beta );
  165. // Feature Values
  166. values.clear();
  167. f->calcFeatureValues( examples, examples_selection, values);
  168. double minValue = (min_element ( values.begin(), values.end() ))->first;
  169. double maxValue = (max_element ( values.begin(), values.end() ))->first;
  170. if ( maxValue - minValue < 1e-7 )
  171. std::cerr << "DTBRandomOblique: Difference between min and max of features values to small!" << std::endl;
  172. // randomly chosen thresholds
  173. for ( int i = 0; i < random_split_tests; i++ )
  174. {
  175. double threshold = (i * (maxValue - minValue ) / (double)random_split_tests)
  176. + minValue;
  177. // preparations
  178. double el, er;
  179. for ( int k = 0 ; k <= maxClassNo ; k++ )
  180. {
  181. distribution_left[k] = 0.0;
  182. distribution_right[k] = 0.0;
  183. }
  184. /** Test the current split */
  185. // Does another split make sense?
  186. double count_left;
  187. double count_right;
  188. if ( ! entropyLeftRight ( values, threshold,
  189. distribution_left, distribution_right,
  190. el, er, count_left, count_right, maxClassNo ) )
  191. continue;
  192. // information gain and entropy
  193. double pl = (count_left) / (count_left + count_right);
  194. double ig = e - pl*el - (1-pl)*er;
  195. if ( use_shannon_entropy )
  196. {
  197. double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
  198. ig = 2*ig / ( e + esplit );
  199. }
  200. #ifdef DETAILTREE
  201. std::cerr << "Testing split #" << i << ": t=" << threshold
  202. << " ig=" << ig << std::endl;
  203. #endif
  204. if ( ig > best_ig )
  205. {
  206. best_ig = ig;
  207. best_threshold = threshold;
  208. for ( int k = 0 ; k <= maxClassNo ; k++ )
  209. {
  210. best_distribution_left[k] = distribution_left[k];
  211. best_distribution_right[k] = distribution_right[k];
  212. }
  213. best_entropy_left = el;
  214. best_entropy_right = er;
  215. }
  216. }
  217. //cleaning up
  218. delete [] distribution_left;
  219. delete [] distribution_right;
  220. // stop criteria: minimum information gain
  221. if ( best_ig < minimum_information_gain )
  222. {
  223. #ifdef DEBUGTREE
  224. std::cerr << "DTBRandomOblique: Minimum information gain reached!" << std::endl;
  225. #endif
  226. delete [] best_distribution_left;
  227. delete [] best_distribution_right;
  228. node->trainExamplesIndices = examples_selection;
  229. return node;
  230. }
  231. /** Save the best split to current node */
  232. node->f = f->clone();
  233. node->threshold = best_threshold;
  234. /** Split examples according to split function */
  235. vector<int> examples_left;
  236. vector<int> examples_right;
  237. examples_left.reserve ( values.size() / 2 );
  238. examples_right.reserve ( values.size() / 2 );
  239. for ( FeatureValuesUnsorted::const_iterator i = values.begin();
  240. i != values.end(); i++ )
  241. {
  242. double value = i->first;
  243. if ( value < best_threshold )
  244. examples_left.push_back ( i->third );
  245. else
  246. examples_right.push_back ( i->third );
  247. }
  248. #ifdef DEBUGTREE
  249. node->f->store( std::cerr );
  250. std::cerr << std::endl;
  251. std::cerr << "mutual information / shannon entropy " << best_ig << " entropy "
  252. << e << " left entropy " << best_entropy_left << " right entropy "
  253. << best_entropy_right << std::endl;
  254. #endif
  255. FullVector distribution_left_sparse ( distribution.size() );
  256. FullVector distribution_right_sparse ( distribution.size() );
  257. for ( int k = 0 ; k <= maxClassNo ; k++ )
  258. {
  259. double l = best_distribution_left[k];
  260. double r = best_distribution_right[k];
  261. if ( l != 0 )
  262. distribution_left_sparse[k] = l;
  263. if ( r != 0 )
  264. distribution_right_sparse[k] = r;
  265. #ifdef DEBUGTREE
  266. if ( (l>0)||(r>0) )
  267. {
  268. std::cerr << "DTBRandomOblique: split of class " << k << " ("
  269. << l << " <-> " << r << ") " << std::endl;
  270. }
  271. #endif
  272. }
  273. delete [] best_distribution_left;
  274. delete [] best_distribution_right;
  275. /** Recursion */
  276. // left child
  277. node->left = buildRecursive ( fp, examples, examples_left,
  278. distribution_left_sparse, best_entropy_left,
  279. maxClassNo, depth+1 );
  280. // right child
  281. node->right = buildRecursive ( fp, examples, examples_right,
  282. distribution_right_sparse, best_entropy_right,
  283. maxClassNo, depth+1 );
  284. return node;
  285. }
  286. /** initial building method */
  287. DecisionNode *DTBRandomOblique::build ( const FeaturePool & fp,
  288. const Examples & examples,
  289. int maxClassNo )
  290. {
  291. int index = 0;
  292. FullVector distribution ( maxClassNo+1 );
  293. vector<int> all;
  294. all.reserve ( examples.size() );
  295. for ( Examples::const_iterator j = examples.begin();
  296. j != examples.end(); j++ )
  297. {
  298. int classno = j->first;
  299. distribution[classno] += j->second.weight;
  300. all.push_back ( index );
  301. index++;
  302. }
  303. double entropy = 0.0;
  304. double sum = 0.0;
  305. for ( int i = 0 ; i < distribution.size(); i++ )
  306. {
  307. double val = distribution[i];
  308. if ( val <= 0.0 ) continue;
  309. entropy -= val*log(val);
  310. sum += val;
  311. }
  312. entropy /= sum;
  313. entropy += log(sum);
  314. return buildRecursive ( fp, examples, all, distribution, entropy, maxClassNo, 0 );
  315. }