/**
* @file DTBOblique.cpp
* @brief random oblique decision tree
* @author Sven Sickert
* @date 10/15/2014
*/

#include <iostream>
#include <time.h>

#include "DTBOblique.h"
#include "vislearning/features/fpfeatures/ConvolutionFeature.h"

#include "core/vector/Algorithms.h"

using namespace OBJREC;

#define DEBUGTREE

using namespace std;
using namespace NICE;
DTBOblique::DTBOblique ( const Config *conf, string section )
{
    random_split_tests = conf->gI(section, "random_split_tests", 10 );
    max_depth = conf->gI(section, "max_depth", 10 );
    minimum_information_gain = conf->gD(section, "minimum_information_gain", 10e-7 );
    minimum_entropy = conf->gD(section, "minimum_entropy", 10e-5 );
    use_shannon_entropy = conf->gB(section, "use_shannon_entropy", false );
    min_examples = conf->gI(section, "min_examples", 50);
    save_indices = conf->gB(section, "save_indices", false);
    lambdaInit = conf->gD(section, "lambdaInit", 0.5 );
}
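
/* Illustrative config section (hypothetical file; key names and default
   values are taken from the reads above, the section name is whatever
   string is passed to the constructor):

   [DTBOblique]
   random_split_tests       = 10     ; thresholds tested per node
   max_depth                = 10     ; maximum tree depth
   minimum_information_gain = 10e-7  ; stop if the best split gains less
   minimum_entropy          = 10e-5  ; stop if node entropy falls below
   use_shannon_entropy      = false  ; normalize gain by split entropy
   min_examples             = 50     ; minimum examples per inner node
   save_indices             = false  ; keep example indices in the nodes
   lambdaInit               = 0.5    ; initial ridge regularization
*/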
DTBOblique::~DTBOblique()
{
}
bool DTBOblique::entropyLeftRight (
    const FeatureValuesUnsorted & values,
    double threshold,
    double* stat_left,
    double* stat_right,
    double & entropy_left,
    double & entropy_right,
    double & count_left,
    double & count_right,
    int maxClassNo )
{
    count_left = 0;
    count_right = 0;

    for ( FeatureValuesUnsorted::const_iterator i = values.begin(); i != values.end(); i++ )
    {
        int classno = i->second;
        double value = i->first;
        if ( value < threshold ) {
            stat_left[classno] += i->fourth;
            count_left += i->fourth;
        }
        else
        {
            stat_right[classno] += i->fourth;
            count_right += i->fourth;
        }
    }

    if ( (count_left == 0) || (count_right == 0) )
        return false;
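
    // The expressions below use the weighted entropy written so that the
    // per-class masses need not be normalized first:
    //   H = -sum_j (s_j/c) * log(s_j/c) = log(c) - (1/c) * sum_j s_j * log(s_j)
    // with class mass s_j and total mass c; both branches evaluate the
    // right-hand side.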
    entropy_left = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_left[j] != 0 )
            entropy_left -= stat_left[j] * log(stat_left[j]);
    entropy_left /= count_left;
    entropy_left += log(count_left);

    entropy_right = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_right[j] != 0 )
            entropy_right -= stat_right[j] * log(stat_right[j]);
    entropy_right /= count_right;
    entropy_right += log(count_right);

    return true;
}
/** refresh data matrix X and label vector y */
void DTBOblique::getDataAndLabel(
    const FeaturePool &fp,
    const Examples &examples,
    const std::vector<int> &examples_selection,
    NICE::Matrix & matX,
    NICE::Vector & vecY )
{
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    int amountParams = f->getParameterLength();
    int amountExamples = examples_selection.size();

    NICE::Matrix X ( amountExamples, amountParams, 0.0 );
    NICE::Vector y ( amountExamples, 0.0 );

    int matIndex = 0;
    for ( vector<int>::const_iterator si = examples_selection.begin();
          si != examples_selection.end();
          si++ )
    {
        const pair<int, Example> & p = examples[*si];
        int classno = p.first;
        const Example & ce = p.second;

        NICE::Vector pixelRepr = f->getFeatureVector( &ce );
        pixelRepr /= pixelRepr.Max();

        // TODO for multiclass scenarios we need ONEvsALL!
        // {0,1} -> {-1,+1}
        double label = 2*classno - 1;
        label *= ce.weight;
        pixelRepr *= ce.weight;
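        // Scaling both the row and its target by the example weight turns
        // the subsequent ridge fit into a weighted least-squares problem:
        // each residual enters the objective with weight^2.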
        y.set ( matIndex, label );
        X.setRow ( matIndex, pixelRepr );
        matIndex++;
    }

    matX = X;
    vecY = y;
}
/** recursive building method */
DecisionNode *DTBOblique::buildRecursive(
    const FeaturePool & fp,
    const Examples & examples,
    std::vector<int> & examples_selection,
    FullVector & distribution,
    double e,
    int maxClassNo,
    int depth,
    double lambdaCurrent )
{
#ifdef DEBUGTREE
    std::cerr << "Examples: " << (int)examples_selection.size()
              << " (depth " << (int)depth << ")" << std::endl;
#endif

    // initialize new node
    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    // stop criteria: max_depth, min_examples, min_entropy
    if ( depth > max_depth
         || (int)examples_selection.size() < min_examples
         || ( (e <= minimum_entropy) && (e != 0.0) ) ) // FIXME
    {
#ifdef DEBUGTREE
        std::cerr << "DTBOblique: Stopping criteria applied!" << std::endl;
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    // refresh/set X and y
    NICE::Matrix X, G;
    NICE::Vector y, beta;
    getDataAndLabel ( fp, examples, examples_selection, X, y );

    // least squares solution
    NICE::Matrix XTX = X.transpose()*X;
    XTX.addDiagonal ( NICE::Vector( XTX.rows(), lambdaCurrent ) );
    choleskyDecomp ( XTX, G );
    choleskyInvert ( G, XTX );
    NICE::Matrix temp = XTX * X.transpose();
    beta.multiply ( temp, y, false );
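
    // Regularized least squares (ridge regression):
    //   beta = (X^T X + lambda*I)^{-1} X^T y
    // The inverse of the symmetric positive definite regularized Gram
    // matrix comes from its Cholesky factor; note that XTX is overwritten
    // with that inverse by choleskyInvert.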
    // variables
    double best_threshold = 0.0;
    double best_ig = -1.0;
    FeatureValuesUnsorted values;
    double *best_distribution_left = new double [maxClassNo+1];
    double *best_distribution_right = new double [maxClassNo+1];
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double best_entropy_left = 0.0;
    double best_entropy_right = 0.0;

    // Setting Convolutional Feature
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    f->setParameterVector ( beta );
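    // The regression coefficients define the oblique split: instead of
    // thresholding a single input dimension, the node thresholds the
    // projection beta^T x of each example onto the learned direction.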
    // Feature Values
    values.clear();
    f->calcFeatureValues ( examples, examples_selection, values );

    double minValue = (min_element ( values.begin(), values.end() ))->first;
    double maxValue = (max_element ( values.begin(), values.end() ))->first;

    if ( maxValue - minValue < 1e-7 )
        std::cerr << "DTBOblique: Difference between min and max of feature values is too small!" << std::endl;
    // find the best threshold by testing evenly spaced candidate splits
    for ( int i = 0; i < random_split_tests; i++ )
    {
        double threshold = (i * (maxValue - minValue) / (double)random_split_tests)
                           + minValue;

        // preparations
        double el, er;
        for ( int k = 0 ; k <= maxClassNo ; k++ )
        {
            distribution_left[k] = 0.0;
            distribution_right[k] = 0.0;
        }

        /** Test the current split */
        // Does another split make sense?
        double count_left;
        double count_right;
        if ( ! entropyLeftRight ( values, threshold,
                                  distribution_left, distribution_right,
                                  el, er, count_left, count_right, maxClassNo ) )
            continue;

        // information gain and entropy
        double pl = count_left / (count_left + count_right);
        double ig = e - pl*el - (1-pl)*er;
        if ( use_shannon_entropy )
        {
            double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
            ig = 2*ig / ( e + esplit );
        }
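        // With use_shannon_entropy the raw gain is normalized by parent
        // and split entropy, ig' = 2*ig / (e + esplit), akin to a
        // symmetric-uncertainty gain ratio; this penalizes splits that
        // send almost all mass to one side.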
        if ( ig > best_ig )
        {
            best_ig = ig;
            best_threshold = threshold;
            for ( int k = 0 ; k <= maxClassNo ; k++ )
            {
                best_distribution_left[k] = distribution_left[k];
                best_distribution_right[k] = distribution_right[k];
            }
            best_entropy_left = el;
            best_entropy_right = er;
        }
    }

    // clean up
    delete [] distribution_left;
    delete [] distribution_right;

    // stop criteria: minimum information gain
    if ( best_ig < minimum_information_gain )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBOblique: Minimum information gain reached!" << std::endl;
#endif
        delete [] best_distribution_left;
        delete [] best_distribution_right;
        node->trainExamplesIndices = examples_selection;
        return node;
    }
    /** Save the best split to current node */
    node->f = f->clone();
    node->threshold = best_threshold;

    /** Split examples according to split function */
    vector<int> examples_left;
    vector<int> examples_right;

    examples_left.reserve ( values.size() / 2 );
    examples_right.reserve ( values.size() / 2 );

    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end(); i++ )
    {
        double value = i->first;
        if ( value < best_threshold )
            examples_left.push_back ( i->third );
        else
            examples_right.push_back ( i->third );
    }
#ifdef DEBUGTREE
    node->f->store( std::cerr );
    std::cerr << std::endl;
    std::cerr << "mutual information / shannon entropy " << best_ig << " entropy "
              << e << " left entropy " << best_entropy_left << " right entropy "
              << best_entropy_right << std::endl;
#endif

    FullVector distribution_left_sparse ( distribution.size() );
    FullVector distribution_right_sparse ( distribution.size() );
    for ( int k = 0 ; k <= maxClassNo ; k++ )
    {
        double l = best_distribution_left[k];
        double r = best_distribution_right[k];
        if ( l != 0 )
            distribution_left_sparse[k] = l;
        if ( r != 0 )
            distribution_right_sparse[k] = r;
#ifdef DEBUGTREE
        if ( (l>0) || (r>0) )
        {
            std::cerr << "DTBOblique: split of class " << k << " ("
                      << l << " <-> " << r << ") " << std::endl;
        }
#endif
    }

    delete [] best_distribution_left;
    delete [] best_distribution_right;
    // update lambda by heuristic [Laptev/Buhmann, 2014]
    double lambdaLeft = lambdaCurrent *
        pow ( ((double)examples_selection.size() / (double)examples_left.size()), (2./f->getParameterLength()) );
    double lambdaRight = lambdaCurrent *
        pow ( ((double)examples_selection.size() / (double)examples_right.size()), (2./f->getParameterLength()) );
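
    // The child regularization grows as
    //   lambda_child = lambda * (N_parent / N_child)^(2/d),
    // where d is the number of projection parameters: the fewer examples
    // a child receives, the stronger its ridge penalty.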
#ifdef DEBUGTREE
    std::cerr << "regularization parameter lambda: left " << lambdaLeft
              << " right " << lambdaRight << std::endl;
#endif

    /** Recursion */
    // left child
    node->left = buildRecursive ( fp, examples, examples_left,
                                  distribution_left_sparse, best_entropy_left,
                                  maxClassNo, depth+1, lambdaLeft );
    // right child
    node->right = buildRecursive ( fp, examples, examples_right,
                                   distribution_right_sparse, best_entropy_right,
                                   maxClassNo, depth+1, lambdaRight );

    return node;
}
/** initial building method */
DecisionNode *DTBOblique::build ( const FeaturePool & fp,
                                  const Examples & examples,
                                  int maxClassNo )
{
    int index = 0;
    FullVector distribution ( maxClassNo+1 );
    vector<int> all;

    all.reserve ( examples.size() );
    for ( Examples::const_iterator j = examples.begin();
          j != examples.end(); j++ )
    {
        int classno = j->first;
        distribution[classno] += j->second.weight;
        all.push_back ( index );
        index++;
    }

    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
        double val = distribution[i];
        if ( val <= 0.0 ) continue;
        entropy -= val*log(val);
        sum += val;
    }
    entropy /= sum;
    entropy += log(sum);

    return buildRecursive ( fp, examples, all, distribution,
                            entropy, maxClassNo, 0, lambdaInit );
}
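
/* Usage sketch (illustrative only; assumes a feature pool whose first
   entry is a ConvolutionFeature and a NICE::Config providing the keys
   read in the constructor -- "myconfig.conf" and the variable names are
   hypothetical):

   NICE::Config conf ( "myconfig.conf" );
   DTBOblique builder ( &conf, "DTBOblique" );

   FeaturePool fp;        // must contain a ConvolutionFeature
   Examples examples;     // weighted (classno, Example) pairs
   int maxClassNo = 1;    // binary case, labels {0,1}

   DecisionNode *root = builder.build ( fp, examples, maxClassNo );
*/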