toyExampleStoreRestore.cpp
/**
 * @file toyExampleStoreRestore.cpp
 * @brief Toy example: train a GPHIKClassifier, store it to disk, restore it, and verify that both classifiers yield the same results
 * @author Alexander Freytag
 * @date 21-12-2013
 */

// STL includes
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <set>
#include <map>

// NICE-core includes
#include <core/basics/Config.h>
#include <core/basics/Timer.h>

// gp-hik-core includes
#include "gp-hik-core/GPHIKClassifier.h"

using namespace std;  // STL basics
using namespace NICE; // nice-core
int main (int argc, char* argv[])
{
  NICE::Config conf ( argc, argv );
  std::string trainData = conf.gS( "main", "trainData", "progs/toyExampleSmallScaleTrain.data" );

  NICE::GPHIKClassifier * classifier;

  //------------- read the training data --------------

  NICE::Matrix dataTrain;
  NICE::Vector yBinTrain;
  NICE::Vector yMultiTrain;

  std::ifstream ifsTrain ( trainData.c_str(), ios::in );

  if ( ifsTrain.good() )
  {
    ifsTrain >> dataTrain;
    ifsTrain >> yBinTrain;
    ifsTrain >> yMultiTrain;
    ifsTrain.close();
  }
  else
  {
    std::cerr << "Unable to read training data, aborting." << std::endl;
    return -1;
  }
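  // at this point dataTrain holds one training example per row, and yBinTrain /
  // yMultiTrain hold the corresponding binary and multi-class labels from the same file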
  //----------------- convert data to sparse data structures ---------
  std::vector< NICE::SparseVector *> examplesTrain;
  examplesTrain.resize( dataTrain.rows() );

  std::vector< NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  for ( int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++ )
  {
    *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  }
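  // each dense row of dataTrain is wrapped into a heap-allocated NICE::SparseVector,
  // since GPHIKClassifier::train below takes a std::vector of SparseVector pointers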
  // TRAIN CLASSIFIER FROM SCRATCH

  classifier = new GPHIKClassifier ( &conf );
  classifier->train ( examplesTrain, yMultiTrain );

  // TEST STORING ABILITIES

  std::string s_destination_save ( "/home/alex/code/nice/gp-hik-core/progs/myClassifier.txt" );

  std::filebuf fbOut;
  fbOut.open ( s_destination_save.c_str(), ios::out );
  std::ostream os (&fbOut);
  //
  classifier->store( os );
  //
  fbOut.close();
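  // the trained classifier has now been serialized to s_destination_save via store();
  // the same file is read back below to check that restore() reproduces the model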
  // TEST RESTORING ABILITIES

  NICE::GPHIKClassifier * classifierRestored = new GPHIKClassifier;

  std::string s_destination_load ( "/home/alex/code/nice/gp-hik-core/progs/myClassifier.txt" );

  std::filebuf fbIn;
  fbIn.open ( s_destination_load.c_str(), ios::in );
  std::istream is (&fbIn);
  //
  classifierRestored->restore( is );
  //
  fbIn.close();
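  // classifierRestored was default-constructed above, so all of its state
  // comes exclusively from the restore() call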
  // TEST both classifiers to produce equal results

  //------------- read the test data --------------

  NICE::Matrix dataTest;
  NICE::Vector yBinTest;
  NICE::Vector yMultiTest;

  std::string testData = conf.gS( "main", "testData", "progs/toyExampleTest.data" );
  std::ifstream ifsTest ( testData.c_str(), ios::in );

  if ( ifsTest.good() )
  {
    ifsTest >> dataTest;
    ifsTest >> yBinTest;
    ifsTest >> yMultiTest;
    ifsTest.close();
  }
  else
  {
    std::cerr << "Unable to read test data, aborting." << std::endl;
    return -1;
  }
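  // the test file follows the same layout as the training file:
  // feature matrix, binary label vector, multi-class label vector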
  // ------------------------------------------
  // ------------- PREPARATION ----------------
  // ------------------------------------------

  // determine classes known during training and corresponding mapping
  // thereby allow for non-continuous class labels
  std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();

  int noClassesKnownTraining ( classesKnownTraining.size() );
  std::map<int,int> mapClNoToIdxTrain;
  std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  for ( int i = 0; i < noClassesKnownTraining; i++, clTrIt++ )
    mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );

  // determine classes known during testing and corresponding mapping
  // thereby allow for non-continuous class labels
  std::set<int> classesKnownTest;
  classesKnownTest.clear();

  // determine which classes we have in our label vector
  // -> MATLAB: myClasses = unique(y);
  for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  {
    if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
    {
      classesKnownTest.insert ( *it );
    }
  }

  int noClassesKnownTest ( classesKnownTest.size() );
  std::map<int,int> mapClNoToIdxTest;
  std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  for ( int i = 0; i < noClassesKnownTest; i++, clTestIt++ )
    mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );

  NICE::Matrix confusionMatrix ( noClassesKnownTraining, noClassesKnownTest, 0.0 );
  NICE::Matrix confusionMatrixRestored ( noClassesKnownTraining, noClassesKnownTest, 0.0 );
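  // both confusion matrices are indexed by (predicted class, ground-truth class),
  // using the training and test mappings built above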
  NICE::Timer t;
  double testTime ( 0.0 );
  double uncertainty;

  int i_loopEnd ( (int)dataTest.rows() );

  for ( int i = 0; i < i_loopEnd; i++ )
  {
    NICE::Vector example ( dataTest.getRow(i) );
    NICE::SparseVector scores;
    int result;

    // classify with the classifier trained from scratch
    t.start();
    classifier->classify( &example, result, scores );
    t.stop();
    testTime += t.getLast();

    confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;

    // classify with the restored classifier
    t.start();
    classifierRestored->classify( &example, result, scores );
    t.stop();
    testTime += t.getLast();

    confusionMatrixRestored( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  }
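  // column-wise L1 normalization turns the raw counts into per-class recognition
  // rates, so the mean of the diagonal is the average recognition rate printed below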
  confusionMatrix.normalizeColumnsL1();
  std::cerr << confusionMatrix << std::endl;
  std::cerr << "average recognition rate: " << confusionMatrix.trace()/confusionMatrix.cols() << std::endl;

  confusionMatrixRestored.normalizeColumnsL1();
  std::cerr << confusionMatrixRestored << std::endl;
  std::cerr << "average recognition rate of restored classifier: " << confusionMatrixRestored.trace()/confusionMatrixRestored.cols() << std::endl;
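  // clean up the heap-allocated objects; this sketch assumes the classifier does not
  // take ownership of the SparseVector examples (adjust if the library's ownership
  // semantics differ)
  delete classifier;
  delete classifierRestored;
  for ( std::vector< NICE::SparseVector *>::iterator it = examplesTrain.begin(); it != examplesTrain.end(); it++ )
  {
    delete *it;
  }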
  return 0;
}