TestCodebookRandomForest.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. /**
  2. * Unit test for Extremely randomized clustering forest (ERC).
  3. *
  4. * @author Johannes Ruehle
  5. * @date 01/05/2014
  6. */
  7. #ifdef NICE_USELIB_CPPUNIT
  8. #include <string>
  9. #include <exception>
  10. #include <iostream>
  11. #include <fstream>
  12. //----------
  13. #include "TestCodebookRandomForest.h"
  14. #include "vislearning/features/simplefeatures/CodebookRandomForest.h"
  15. #include "vislearning/features/fpfeatures/VectorFeature.h"
  16. #include "vislearning/cbaselib/FeaturePool.h"
  17. const bool verbose = false;
  18. const bool verboseStartEnd = true;
  19. using namespace OBJREC;
  20. using namespace NICE;
  21. using namespace std;
  22. CPPUNIT_TEST_SUITE_REGISTRATION( TestCodebookRandomForest );
  23. void TestCodebookRandomForest::setUp() {
  24. }
  25. void TestCodebookRandomForest::tearDown() {
  26. }
  27. void TestCodebookRandomForest::testCodebookRandomForest()
  28. {
  29. if (verboseStartEnd)
  30. std::cerr << "================== TestCodebookRandomForest::TestCodebookRandomForest ===================== " << std::endl;
  31. try
  32. {
  33. Matrix mX;
  34. Vector vY;
  35. Vector vY_multi;
  36. //ifstream ifs ("toyExample1.data", ios::in);
  37. // ifstream ifs ("toyExampleLargeScale.data", ios::in);
  38. ifstream ifs ("toyExampleLargeLargeScale.data", ios::in);
  39. CPPUNIT_ASSERT ( ifs.good() );
  40. ifs >> mX;
  41. ifs >> vY;
  42. ifs >> vY_multi;
  43. ifs.close();
  44. if (verbose)
  45. {
  46. std::cerr << "data loaded: mX" << std::endl;
  47. std::cerr << mX << std::endl;
  48. std::cerr << "vY: " << std::endl;
  49. std::cerr << vY << std::endl;
  50. std::cerr << "vY_multi: " << std::endl;
  51. std::cerr << vY_multi << std::endl;
  52. }
  53. int iNumFeatureDimension = mX.cols();
  54. // memory layout needs to be transposed into rows x column: features x samples
  55. // features must lay next to each other in memory, so that each feature vector can
  56. // be adressed by a starting pointer and the number of feature dimensions to come.
  57. Matrix mX_transposed = mX.transpose();
  58. Examples examples;
  59. bool bSuccess = Examples::wrapExamplesAroundFeatureMatrix(mX_transposed, vY_multi, examples);
  60. CPPUNIT_ASSERT( bSuccess );
  61. CPPUNIT_ASSERT( examples.size() == mX.rows() );
  62. //----------------- create raw feature mapping -------------
  63. OBJREC::FeaturePool fp;
  64. OBJREC::VectorFeature *pVecFeature = new OBJREC::VectorFeature(iNumFeatureDimension);
  65. pVecFeature->explode(fp);
  66. //----------------- debug features -------------
  67. OBJREC::Example t_Exp = examples[0].second;
  68. NICE::Vector t_FeatVector;
  69. fp.calcFeatureVector(t_Exp, t_FeatVector);
  70. std::cerr << "first full Feature Vec: " <<t_FeatVector << std::endl;
  71. //----------------- train our random Forest -------------
  72. NICE::Config conf("config.conf");
  73. OBJREC::FPCRandomForests *pRandForest = new OBJREC::FPCRandomForests(&conf,"RandomForest");
  74. pRandForest->train(fp, examples);
  75. //----------------- create codebook ERC clusterer -------------
  76. int nMaxDepth = conf.gI("CodebookRandomForest", "maxDepthTree",10);
  77. int nMaxCodebookSize = conf.gI("CodebookRandomForest", "maxCodebookSize",100);
  78. std::cerr << "maxDepthTree " << nMaxDepth << std::endl;
  79. OBJREC::CodebookRandomForest *pCodebookRandomForest = new OBJREC::CodebookRandomForest(pRandForest, nMaxDepth, nMaxCodebookSize);
  80. //----------------- quantize samples into histogram -------------
  81. size_t iNumCodewords = pCodebookRandomForest->getCodebookSize();
  82. NICE::Vector histogram(iNumCodewords, 0.0f);
  83. int t_iCodebookEntry; double t_fWeight; double t_fDistance;
  84. for (size_t i = 0; i < examples.size(); i++ )
  85. {
  86. Example &t_Ex = examples[i].second;
  87. pCodebookRandomForest->voteVQ( *t_Ex.vec, histogram, t_iCodebookEntry, t_fWeight, t_fDistance );
  88. std::cerr << i << ": " << "CBEntry " << t_iCodebookEntry << " Weight: " << t_fWeight << " Distance: " << t_fDistance << std::endl;
  89. }
  90. std::cerr << "histogram: " << histogram << std::endl;
  91. // test of store and restore
  92. std::string t_sDestinationSave = "codebookRF.save.txt";
  93. std::ofstream ofs;
  94. ofs.open (t_sDestinationSave.c_str(), std::ofstream::out);
  95. pCodebookRandomForest->store( ofs );
  96. ofs.close();
  97. // restore
  98. OBJREC::CodebookRandomForest *pTestCRF = new OBJREC::CodebookRandomForest(-1, -1);
  99. std::ifstream ifs2;
  100. ifs2.open (t_sDestinationSave.c_str() );
  101. pTestCRF->restore( ifs2 );
  102. ifs2.close();
  103. CPPUNIT_ASSERT_EQUAL(iNumCodewords, pTestCRF->getCodebookSize() );
  104. CPPUNIT_ASSERT_EQUAL(nMaxDepth, pTestCRF->getMaxDepth() );
  105. CPPUNIT_ASSERT_EQUAL(nMaxCodebookSize, pTestCRF->getRestrictedCodebookSize() );
  106. NICE::Vector histogramCompare(iNumCodewords, 0.0f);
  107. for (size_t i = 0; i < examples.size(); i++ )
  108. {
  109. Example &t_Ex = examples[i].second;
  110. pTestCRF->voteVQ( *t_Ex.vec, histogramCompare, t_iCodebookEntry, t_fWeight, t_fDistance );
  111. }
  112. std::cerr << "histogram of restored CodebookRandomForest: " << histogramCompare << std::endl;
  113. std::cerr << "comparing histograms...";
  114. for (size_t i = 0; i < iNumCodewords; i++ )
  115. {
  116. CPPUNIT_ASSERT_DOUBLES_EQUAL(histogram[i], histogramCompare[i], 1e-5 );
  117. }
  118. std::cerr << "equal..." << std::endl;
  119. // clean up
  120. delete pTestCRF;
  121. delete pCodebookRandomForest;
  122. examples.clean();
  123. delete pVecFeature;
  124. if (verboseStartEnd)
  125. std::cerr << "================== TestCodebookRandomForest::TestCodebookRandomForest done ===================== " << std::endl;
  126. }
  127. catch(std::exception &e)
  128. {
  129. std::cerr << "exception occured: " << e.what() << std::endl;
  130. }
  131. }
  132. #endif