RegPreRandomForests.cpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * @file RegPreRandomForests.cpp
  3. * @brief Combination of a regression method with a pre-clustering using a random forest
  4. * @author Sven Sickert
  5. * @date 07/12/2013
  6. */
  7. #include "vislearning/regression/regcombination/RegPreRandomForests.h"
  8. #include <iostream>
  9. #include <assert.h>
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. RegPreRandomForests::RegPreRandomForests(const Config * conf,
  14. const string & section,
  15. RegressionAlgorithm *_leafRegressionPrototype )
  16. : leafRegressionPrototype(_leafRegressionPrototype)
  17. {
  18. string cluster_section = conf->gS ( section, "cluster_section", "RandomForest" );
  19. mEx = conf->gI ( "RTBRandom", "min_examples", 500 );
  20. randomforest = new RegRandomForests( conf, cluster_section );
  21. }
  22. RegPreRandomForests::~RegPreRandomForests()
  23. {
  24. // delete the random forest
  25. if ( randomforest != NULL )
  26. delete randomforest;
  27. // delete all regression methods in the leafs
  28. for ( map<RegressionNode *, RegressionAlgorithm * >::const_iterator it = leafRegressions.begin();
  29. it != leafRegressions.end(); it++ )
  30. {
  31. RegressionAlgorithm * lr = it->second;
  32. if ( lr != NULL )
  33. delete lr;
  34. }
  35. // delete regression prototype
  36. if ( leafRegressionPrototype != NULL )
  37. delete leafRegressionPrototype;
  38. }
  39. void RegPreRandomForests::teach ( const VVector & X, const Vector & y )
  40. {
  41. randomforest->teach ( X, y );
  42. if ( leafRegressionPrototype != NULL )
  43. {
  44. vector<RegressionNode *> leafNodes;
  45. randomforest->getAllLeafNodes ( leafNodes );
  46. int lsize = leafNodes.size();
  47. int leafNo = 0;
  48. cerr << "leafnodes: " << lsize << endl;
  49. #pragma omp parallel for
  50. for ( int l = 0; l < lsize; l++ )
  51. {
  52. leafNo++;
  53. RegressionNode *node = leafNodes[l];
  54. if ( !node->isLeaf() ){
  55. fprintf( stderr, "RegPreRandomForests::predict: ID #%d not a leaf node!", leafNo );
  56. continue;
  57. }
  58. vector<int> leafTrainInds = node->trainExamplesIndices;
  59. cerr << "Teaching regression method for leaf " << leafNo-1 << "..." << endl;
  60. cerr << "examples in leave: " << leafTrainInds.size() << endl;
  61. assert ( leafTrainInds.size() > 0 );
  62. sort ( leafTrainInds.begin(), leafTrainInds.end() );
  63. NICE::VVector leafTrainData;
  64. vector<double> tmpVals;
  65. for ( int i = 0; i < (int)leafTrainInds.size(); i++ )
  66. {
  67. if ( leafTrainInds[i] >= 0 && leafTrainInds[i] < (int)y.size() )
  68. {
  69. leafTrainData.push_back( X[ leafTrainInds[i] ] );
  70. tmpVals.push_back( y[ leafTrainInds[i] ] );
  71. }
  72. }
  73. if (leafTrainData.size() <= 0 ) continue;
  74. NICE::Vector leafTrainVals( tmpVals );
  75. RegressionAlgorithm *lr = leafRegressionPrototype->clone();
  76. lr->teach( leafTrainData, leafTrainVals );
  77. leafRegressions.insert ( pair< RegressionNode *, RegressionAlgorithm *> ( node, lr ) );
  78. }
  79. }
  80. }
  81. double RegPreRandomForests::predict ( const Vector & x )
  82. {
  83. double pred = 0.0;
  84. vector<RegressionNode *> leafNodes;
  85. // traverse the forest and obtain all innvolved leaf nodes
  86. randomforest->getLeafNodes ( x, leafNodes );
  87. for ( vector<RegressionNode *>::const_iterator it = leafNodes.begin();
  88. it != leafNodes.end(); it++ )
  89. {
  90. RegressionNode *node = *it;
  91. map<RegressionNode *, RegressionAlgorithm *>::const_iterator leafRegressionIt =
  92. leafRegressions.find( node );
  93. if ( leafRegressionIt == leafRegressions.end() )
  94. {
  95. // this leaf has no associated regression method
  96. // -> we will use the random forest result
  97. pred += node->predVal;
  98. continue;
  99. }
  100. RegressionAlgorithm *leafRegression = leafRegressionIt->second;
  101. pred += leafRegression->predict( x );
  102. }
  103. pred /= leafNodes.size();
  104. return pred;
  105. }
  106. void RegPreRandomForests::clear ()
  107. {
  108. map<RegressionNode *, RegressionAlgorithm *>::iterator iter;
  109. for ( iter = leafRegressions.begin(); iter != leafRegressions.end(); iter++ )
  110. {
  111. iter->second->clear();
  112. }
  113. randomforest->clear();
  114. }
  115. void RegPreRandomForests::store ( ostream & os, int format ) const
  116. {
  117. cerr << "RegPreRandomForest::store: not yet implemented" << endl;
  118. }
  119. void RegPreRandomForests::restore ( istream& is, int format )
  120. {
  121. cerr << "RegPreRandomForest::restore: not yet implemented" << endl;
  122. }