/**
 * @file DTBObliqueLS.cpp
 * @brief builder for random oblique decision trees with least squares splits
 * @author Sven Sickert
 * @date 10/15/2014
 */

#include <iostream>
#include <cstdlib>      // rand, srand
#include <cmath>        // pow, log, ceil
#include <algorithm>    // sort, unique, random_shuffle, min_element, max_element
#include <time.h>

#include "DTBObliqueLS.h"
#include "SCInformationGain.h"
#include "SCGiniIndex.h"

#include "vislearning/features/fpfeatures/ConvolutionFeature.h"

#include "core/vector/Algorithms.h"

using namespace OBJREC;

//#define DEBUGTREE
DTBObliqueLS::DTBObliqueLS ( const NICE::Config *conf, std::string section )
{
    saveIndices = conf->gB( section, "save_indices", false );
    useDynamicRegularization = conf->gB( section, "use_dynamic_regularization", true );
    multiClassMode = conf->gI( section, "multi_class_mode", 0 ); // integer mode (0/1/other), hence gI rather than gB
    splitSteps = conf->gI( section, "split_steps", 20 );
    maxDepth = conf->gI( section, "max_depth", 10 );
    regularizationType = conf->gI( section, "regularization_type", 1 );
    lambdaInit = conf->gD( section, "lambda_init", 0.5 );

    std::string splitCrit = conf->gS( section, "split_criterion", "information_gain" );
    if ( splitCrit == "information_gain" )
        splitCriterion = new SCInformationGain( conf );
    else if ( splitCrit == "gini_index" )
        splitCriterion = new SCGiniIndex( conf );
    else
    {
        std::cerr << "DTBObliqueLS::DTBObliqueLS: No valid splitting criterion defined!" << std::endl;
        splitCriterion = NULL;
    }

    if ( conf->gB( section, "start_random_generator", true ) )
        srand( time(NULL) );
}
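/* A sketch of the matching configuration section (key names are taken from
 * the gB/gI/gD calls above; the values shown are just the defaults, and the
 * section name "DTBObliqueLS" is an assumption: it is whatever string is
 * passed in as 'section'):
 *
 *   [DTBObliqueLS]
 *   save_indices               = false
 *   use_dynamic_regularization = true
 *   multi_class_mode           = 0     ; 0: one-vs-one, 1: one-vs-all, else: many-vs-many
 *   split_steps                = 20
 *   max_depth                  = 10
 *   regularization_type        = 1     ; see regularizeDataMatrix() below
 *   lambda_init                = 0.5
 *   split_criterion            = information_gain   ; or gini_index
 *   start_random_generator     = true
 */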
DTBObliqueLS::~DTBObliqueLS()
{
    if ( splitCriterion != NULL )
        delete splitCriterion;
}
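/** Relabel y (and blank unused rows of X) for a binary subproblem.
 *  Worked example for one-vs-one mode (multiClassMode == 0): with labels
 *  y = (0, 1, 2, 1), posClass = 1 and negClass = 2, the result is
 *  y = (0, +1, -1, +1), and the row of X belonging to class 0 is zeroed
 *  so that it drops out of the least squares fit. Returns false if either
 *  the positive or the negative class has no examples in this node.
 */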
bool DTBObliqueLS::adaptDataAndLabelForMultiClass (
    const int posClass,
    const int negClass,
    NICE::Matrix & X,
    NICE::Vector & y )
{
    int posCount = 0;
    int negCount = 0;

    // One-vs-one: Transforming into {-1,0,+1} problem
    if ( multiClassMode == 0 )
        for ( int i = 0; i < (int)y.size(); i++ )
        {
            if ( y[i] == posClass )
            {
                y[i] = 1.0;
                posCount++;
            }
            else if ( y[i] == negClass )
            {
                y[i] = -1.0;
                negCount++;
            }
            else
            {
                y[i] = 0.0;
                X.setRow( i, NICE::Vector( X.cols(), 0.0 ) );
            }
        }
    // One-vs-all: Transforming into {-1,+1} problem
    else if ( multiClassMode == 1 )
        for ( int i = 0; i < (int)y.size(); i++ )
        {
            if ( y[i] == posClass )
            {
                y[i] = 1.0;
                posCount++;
            }
            else
            {
                y[i] = -1.0;
                negCount++;
            }
        }
    // Many-vs-many: Transforming into {-1,+1} problem
    else
    {
        // get existing classes
        std::vector<double> unClass = y.std_vector();
        std::sort( unClass.begin(), unClass.end() );
        unClass.erase( std::unique( unClass.begin(), unClass.end() ), unClass.end() );

        // randomly split the set of classes into two buckets
        std::random_shuffle( unClass.begin(), unClass.end() );
        int firstHalf = std::ceil( unClass.size() / 2.0 );

        for ( int i = 0; i < (int)y.size(); i++ )
        {
            bool wasFound = false;
            int c = 0;

            // assign new labels
            while ( (!wasFound) && (c < firstHalf) )
            {
                if ( y[i] == unClass[c] )
                {
                    wasFound = true;
                }
                c++;
            }
            if ( wasFound )
            {
                y[i] = 1.0;
                posCount++;
            }
            else
            {
                y[i] = -1.0;
                negCount++;
            }
        }
    }

    return ( (posCount > 0) && (negCount > 0) );
}
/** refresh data matrix X and label vector y */
void DTBObliqueLS::getDataAndLabel(
    const FeaturePool &fp,
    const Examples &examples,
    const std::vector<int> &examples_selection,
    NICE::Matrix & X,
    NICE::Vector & y,
    NICE::Vector & w )
{
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    int amountParams = f->getParameterLength();
    int amountExamples = examples_selection.size();

    X = NICE::Matrix( amountExamples, amountParams, 0.0 );
    y = NICE::Vector( amountExamples, 0.0 );
    w = NICE::Vector( amountExamples, 1.0 );

    int matIndex = 0;
    for ( std::vector<int>::const_iterator si = examples_selection.begin();
          si != examples_selection.end();
          si++ )
    {
        const std::pair<int, Example> & p = examples[*si];
        const Example & ex = p.second;

        NICE::Vector pixelRepr ( amountParams, 1.0 );
        f->getFeatureVector( &ex, pixelRepr );

        double label = p.first;
        pixelRepr *= ex.weight;

        w.set ( matIndex, ex.weight );
        y.set ( matIndex, label );
        X.setRow ( matIndex, pixelRepr );

        matIndex++;
    }
}
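/* Note: each row of X is the example's feature vector scaled by its weight
 * (pixelRepr *= ex.weight), and the raw weights are returned separately in w
 * so that buildRecursive can scale the labels as well (yCur *= weights).
 * The subsequent least squares fit is therefore weighted.
 */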
void DTBObliqueLS::regularizeDataMatrix(
    const NICE::Matrix &X,
    NICE::Matrix &XTXreg,
    const int regOption,
    const double lambda )
{
    XTXreg = X.transpose()*X;
    NICE::Matrix R;
    const int dim = X.cols();

    switch ( regOption )
    {
        // identity matrix
        case 0:
            R.resize( dim, dim );
            R.setIdentity();
            R *= lambda;
            XTXreg += R;
            break;

        // difference operator, k=1
        case 1:
            R.resize( dim-1, dim );
            R.set( 0.0 );
            for ( int r = 0; r < dim-1; r++ )
            {
                R(r,r)   =  1.0;
                R(r,r+1) = -1.0;
            }
            R = R.transpose()*R;
            R *= lambda;
            XTXreg += R;
            break;

        // difference operator, k=2
        case 2:
            R.resize( dim-2, dim );
            R.set( 0.0 );
            for ( int r = 0; r < dim-2; r++ )
            {
                R(r,r)   =  1.0;
                R(r,r+1) = -2.0;
                R(r,r+2) =  1.0;
            }
            R = R.transpose()*R;
            R *= lambda;
            XTXreg += R;
            break;

        // as in [Chen et al., 2012]
        case 3:
        {
            NICE::Vector q ( dim, (1.0-lambda) );
            q[0] = 1.0;
            NICE::Matrix Q;
            Q.tensorProduct( q, q );
            R.resize( dim, dim );
            for ( int r = 0; r < dim; r++ )
            {
                for ( int c = 0; c < dim; c++ )
                    R(r,c) = XTXreg(r,c) * Q(r,c);
                R(r,r) = q[r] * XTXreg(r,r);
            }
            XTXreg = R;
            break;
        }

        // no regularization
        default:
            std::cerr << "DTBObliqueLS::regularizeDataMatrix: No regularization applied!"
                      << std::endl;
            break;
    }
}
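/* All cases except 3 build XTXreg = X^T X + lambda * R^T R, where R penalizes
 * either the parameter norm (case 0, classical ridge regression) or the k-th
 * order differences between neighboring filter parameters (cases 1 and 2).
 * For dim = 4 and k = 1, for instance, R before the R^T R product is
 *
 *       (  1 -1  0  0 )
 *   R = (  0  1 -1  0 )
 *       (  0  0  1 -1 )
 *
 * Case 3 instead rescales entry (r,c) of X^T X by q[r]*q[c] off the diagonal
 * and by q[r] on the diagonal, with q[r] = 1-lambda everywhere except
 * q[0] = 1.0, following [Chen et al., 2012].
 */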
void DTBObliqueLS::findBestSplitThreshold (
    FeatureValuesUnsorted &values,
    SplitInfo &bestSplitInfo,
    const NICE::Vector &params,
    const int &maxClassNo )
{
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double minValue = (min_element ( values.begin(), values.end() ))->first;
    double maxValue = (max_element ( values.begin(), values.end() ))->first;

    if ( maxValue - minValue < 1e-7 )
        std::cerr << "DTBObliqueLS: Difference between min and max of feature values too small!"
                  << " [" << minValue << "," << maxValue << "]" << std::endl;

    // get the best threshold using complete search over a regular grid:
    // threshold_i = minValue + i * (maxValue - minValue) / splitSteps
    for ( int i = 0; i < splitSteps; i++ )
    {
        double threshold = (i * (maxValue - minValue) / (double)splitSteps)
                            + minValue;
        // preparations
        for ( int k = 0 ; k <= maxClassNo ; k++ )
        {
            distribution_left[k] = 0.0;
            distribution_right[k] = 0.0;
        }

        /** Test the current split */
        SplittingCriterion *curSplit = splitCriterion->clone();

        if ( ! curSplit->evaluateSplit ( values, threshold,
                distribution_left, distribution_right, maxClassNo ) )
        {
            // free the clone before skipping this threshold (avoids a leak)
            delete curSplit;
            continue;
        }

        // get value for impurity
        double purity = curSplit->computePurity();
        double entropy = curSplit->getEntropy();

        if ( purity > bestSplitInfo.purity )
        {
            bestSplitInfo.purity = purity;
            bestSplitInfo.entropy = entropy;
            bestSplitInfo.threshold = threshold;
            bestSplitInfo.params = params;

            for ( int k = 0 ; k <= maxClassNo ; k++ )
            {
                bestSplitInfo.distLeft[k]  = distribution_left[k];
                bestSplitInfo.distRight[k] = distribution_right[k];
            }
        }
        delete curSplit;
    }

    // cleaning up
    delete [] distribution_left;
    delete [] distribution_right;
}
/** recursive building method */
DecisionNode *DTBObliqueLS::buildRecursive(
    const FeaturePool & fp,
    const Examples & examples,
    std::vector<int> & examples_selection,
    FullVector & distribution,
    double entropy,
    int maxClassNo,
    int depth,
    double lambdaCurrent )
{
    std::cerr << "DTBObliqueLS: Examples: " << (int)examples_selection.size()
              << ", Depth: " << (int)depth << ", Entropy: " << entropy << std::endl;

    // initialize new node
    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    // stopping criteria
    if ( ( entropy <= splitCriterion->getMinimumEntropy() )
        || ( (int)examples_selection.size() < splitCriterion->getMinimumExamples() )
        || ( depth > maxDepth ) )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBObliqueLS: Stopping criteria applied!" << std::endl;
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    // variables
    FeatureValuesUnsorted values;
    SplitInfo bestSplitInfo;
    bestSplitInfo.threshold = 0.0;
    bestSplitInfo.purity = -1.0;
    bestSplitInfo.entropy = 0.0;
    bestSplitInfo.distLeft = new double [maxClassNo+1];
    bestSplitInfo.distRight = new double [maxClassNo+1];

    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    bestSplitInfo.params = f->getParameterVector();

    // Creating data matrix X and label vector y
    NICE::Matrix X;
    NICE::Vector y, params, weights;
    getDataAndLabel( fp, examples, examples_selection, X, y, weights );

    // Reducing the multi-class problem to a binary one by drawing a random
    // class pair; posClass and negClass may coincide, in which case the
    // while loop simply retries until both sides have examples
    bool hasExamples = false;
    NICE::Vector yCur;
    NICE::Matrix XCur;
    while ( !hasExamples )
    {
        int posClass, negClass;
        posClass = rand() % (maxClassNo+1);
        negClass = (posClass + (rand() % maxClassNo)) % (maxClassNo+1);
        yCur = y;
        XCur = X;
        hasExamples = adaptDataAndLabelForMultiClass(
            posClass, negClass, XCur, yCur );
    }
    yCur *= weights;
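    /* Closed-form solution of the regularized least squares problem
     *
     *   params = argmin_beta ||XCur*beta - yCur||^2 + lambda*penalty(beta)
     *          = (XCur^T XCur + lambda*R)^{-1} XCur^T yCur
     *
     * The steps below regularize the Gram matrix, factorize it via Cholesky
     * (XTXr = G*G^T), invert it from the factor, and apply the resulting
     * pseudo-inverse to the label vector.
     */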
    // Preparing the system of linear equations
    NICE::Matrix XTXr, G, temp;
    regularizeDataMatrix( XCur, XTXr, regularizationType, lambdaCurrent );
    choleskyDecomp( XTXr, G );
    choleskyInvert( G, XTXr );
    temp = XTXr * XCur.transpose();

    // Solve the system of linear equations in a least squares manner
    params.multiply( temp, yCur, false );

    // Updating the parameter vector of the convolutional feature
    f->setParameterVector( params );

    // Feature values
    values.clear();
    f->calcFeatureValues( examples, examples_selection, values );

    // complete search for the best threshold
    findBestSplitThreshold ( values, bestSplitInfo, params, maxClassNo );

    // stopping criterion: minimum purity reached?
    if ( bestSplitInfo.purity < splitCriterion->getMinimumPurity() )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBObliqueLS: Minimum purity reached!" << std::endl;
#endif
        delete [] bestSplitInfo.distLeft;
        delete [] bestSplitInfo.distRight;
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    /** Save the best split to the current node */
    f->setParameterVector( bestSplitInfo.params );
    values.clear();
    f->calcFeatureValues( examples, examples_selection, values );
    node->f = f->clone();
    node->threshold = bestSplitInfo.threshold;

    /** Split examples according to the best split function */
    std::vector<int> examples_left;
    std::vector<int> examples_right;

    examples_left.reserve ( values.size() / 2 );
    examples_right.reserve ( values.size() / 2 );
    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end(); i++ )
    {
        if ( i->first < bestSplitInfo.threshold )
            examples_left.push_back ( i->third );
        else
            examples_right.push_back ( i->third );
    }

#ifdef DEBUGTREE
//    node->f->store( std::cerr );
//    std::cerr << std::endl;
#endif

    FullVector distribution_left_sparse ( distribution.size() );
    FullVector distribution_right_sparse ( distribution.size() );
    for ( int k = 0 ; k <= maxClassNo ; k++ )
    {
        double l = bestSplitInfo.distLeft[k];
        double r = bestSplitInfo.distRight[k];
        if ( l != 0 )
            distribution_left_sparse[k] = l;
        if ( r != 0 )
            distribution_right_sparse[k] = r;
#ifdef DEBUGTREE
        std::cerr << "DTBObliqueLS: Split of Class " << k << " ("
                  << l << " <-> " << r << ") " << std::endl;
#endif
    }
    delete [] bestSplitInfo.distLeft;
    delete [] bestSplitInfo.distRight;

    // update lambda by the heuristic of [Laptev/Buhmann, 2014]
    double lambdaLeft, lambdaRight;
    if ( useDynamicRegularization )
    {
        lambdaLeft = lambdaCurrent *
            pow( ((double)examples_selection.size()/(double)examples_left.size()),
                 (2./f->getParameterLength()) );
        lambdaRight = lambdaCurrent *
            pow( ((double)examples_selection.size()/(double)examples_right.size()),
                 (2./f->getParameterLength()) );
    }
    else
    {
        lambdaLeft = lambdaCurrent;
        lambdaRight = lambdaCurrent;
    }

    /** Recursion */
    // left child
    node->left = buildRecursive ( fp, examples, examples_left,
                distribution_left_sparse, bestSplitInfo.entropy,
                maxClassNo, depth+1, lambdaLeft );
    // right child
    node->right = buildRecursive ( fp, examples, examples_right,
                distribution_right_sparse, bestSplitInfo.entropy,
                maxClassNo, depth+1, lambdaRight );

    return node;
}
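/* The entropy handed to the first buildRecursive call is computed in build()
 * below via the identity (v_k: weighted class counts, s = sum_k v_k):
 *
 *   H = - sum_k (v_k/s) * log(v_k/s)
 *     = - (1/s) * sum_k v_k * log(v_k) + log(s)
 *
 * which corresponds to the "entropy /= sum; entropy += log(sum);" pair.
 */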
/** initial building method */
DecisionNode *DTBObliqueLS::build ( const FeaturePool & fp,
        const Examples & examples,
        int maxClassNo )
{
    int index = 0;

    FullVector distribution ( maxClassNo+1 );
    std::vector<int> all;

    all.reserve ( examples.size() );
    for ( Examples::const_iterator j = examples.begin();
          j != examples.end(); j++ )
    {
        int classno = j->first;
        distribution[classno] += j->second.weight;
        all.push_back ( index );
        index++;
    }

    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
        double val = distribution[i];
        if ( val <= 0.0 ) continue;
        entropy -= val*log(val);
        sum += val;
    }
    entropy /= sum;
    entropy += log(sum);

    return buildRecursive ( fp, examples, all, distribution,
                entropy, maxClassNo, 0, lambdaInit );
}