toyExample.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
/**
 * @file toyExample.cpp
 * @brief Demo program showing how to call some methods of the GPHIKClassifier class
 * @author Alexander Freytag
 * @date 19-10-2012
 */
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

#include <core/basics/Config.h>
#include <core/basics/Timer.h>
#include <core/vector/MatrixT.h>
#include <core/vector/VectorT.h>

#include "gp-hik-core/GPHIKClassifier.h"

using namespace std; //C basics
using namespace NICE; // nice-core
  16. int main (int argc, char* argv[])
  17. {
  18. Config conf ( argc, argv );
  19. std::string trainData = conf.gS( "main", "trainData", "progs/toyExampleSmallScaleTrain.data" );
  20. bool b_debug = conf.gB( "main", "debug", false );
  21. //------------- read the training data --------------
  22. NICE::Matrix dataTrain;
  23. NICE::Vector yBinTrain;
  24. NICE::Vector yMultiTrain;
  25. if ( b_debug )
  26. {
  27. dataTrain.resize(6,3);
  28. dataTrain.set(0);
  29. dataTrain(0,0) = 0.2; dataTrain(0,1) = 0.3; dataTrain(0,2) = 0.5;
  30. dataTrain(1,0) = 0.3; dataTrain(1,1) = 0.2; dataTrain(1,2) = 0.5;
  31. dataTrain(2,0) = 0.9; dataTrain(2,1) = 0.0; dataTrain(2,2) = 0.1;
  32. dataTrain(3,0) = 0.8; dataTrain(3,1) = 0.1; dataTrain(3,2) = 0.1;
  33. dataTrain(4,0) = 0.1; dataTrain(4,1) = 0.1; dataTrain(4,2) = 0.8;
  34. dataTrain(5,0) = 0.1; dataTrain(5,1) = 0.0; dataTrain(5,2) = 0.9;
  35. yMultiTrain.resize(6);
  36. yMultiTrain[0] = 1; yMultiTrain[1] = 1;
  37. yMultiTrain[2] = 2; yMultiTrain[3] = 2;
  38. yMultiTrain[2] = 3; yMultiTrain[3] = 3;
  39. }
  40. else
  41. {
  42. std::ifstream ifsTrain ( trainData.c_str() , ios::in );
  43. if (ifsTrain.good() )
  44. {
  45. ifsTrain >> dataTrain;
  46. ifsTrain >> yBinTrain;
  47. ifsTrain >> yMultiTrain;
  48. ifsTrain.close();
  49. }
  50. else
  51. {
  52. std::cerr << "Unable to read training data, aborting." << std::endl;
  53. return -1;
  54. }
  55. }
  56. //----------------- convert data to sparse data structures ---------
  57. std::vector< NICE::SparseVector *> examplesTrain;
  58. examplesTrain.resize( dataTrain.rows() );
  59. std::vector< NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  60. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  61. {
  62. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  63. }
  64. std::cerr << "Number of training examples: " << examplesTrain.size() << std::endl;
  65. //----------------- train our classifier -------------
  66. conf.sB("GPHIKClassifier", "verbose", false);
  67. GPHIKClassifier * classifier = new GPHIKClassifier ( &conf );
  68. classifier->train ( examplesTrain , yMultiTrain );
  69. // ------------------------------------------
  70. // ------------- CLASSIFICATION --------------
  71. // ------------------------------------------
  72. //------------- read the test data --------------
  73. NICE::Matrix dataTest;
  74. NICE::Vector yBinTest;
  75. NICE::Vector yMultiTest;
  76. if ( b_debug )
  77. {
  78. dataTest.resize(1,3);
  79. dataTest.set(0);
  80. dataTest(0,0) = 0.3; dataTest(0,1) = 0.4; dataTest(0,2) = 0.3;
  81. yMultiTrain.resize(1);
  82. yMultiTrain[0] = 1;
  83. }
  84. else
  85. {
  86. std::string testData = conf.gS( "main", "testData", "progs/toyExampleTest.data" );
  87. std::ifstream ifsTest ( testData.c_str(), ios::in );
  88. if (ifsTest.good() )
  89. {
  90. ifsTest >> dataTest;
  91. ifsTest >> yBinTest;
  92. ifsTest >> yMultiTest;
  93. ifsTest.close();
  94. }
  95. else
  96. {
  97. std::cerr << "Unable to read test data, aborting." << std::endl;
  98. return -1;
  99. }
  100. }
  101. //TODO adapt this to the actual number of classes
  102. NICE::Matrix confusionMatrix(3, 3, 0.0);
  103. NICE::Timer t;
  104. double testTime (0.0);
  105. double uncertainty;
  106. int i_loopEnd ( (int)dataTest.rows() );
  107. if ( b_debug )
  108. {
  109. i_loopEnd = 1;
  110. }
  111. for (int i = 0; i < i_loopEnd ; i++)
  112. {
  113. //----------------- convert data to sparse data structures ---------
  114. NICE::SparseVector * example = new NICE::SparseVector( dataTest.getRow(i) );
  115. int result;
  116. NICE::SparseVector scores;
  117. // and classify
  118. t.start();
  119. classifier->classify( example, result, scores );
  120. t.stop();
  121. testTime += t.getLast();
  122. std::cerr << " scores.size(): " << scores.size() << std::endl;
  123. scores.store(std::cerr);
  124. if ( b_debug )
  125. {
  126. classifier->predictUncertainty( example, uncertainty );
  127. std::cerr << " uncertainty: " << uncertainty << std::endl;
  128. }
  129. else
  130. {
  131. confusionMatrix(result, yMultiTest[i]) += 1.0;
  132. }
  133. }
  134. if ( !b_debug )
  135. {
  136. std::cerr << "Time for testing: " << testTime << std::endl;
  137. confusionMatrix.normalizeColumnsL1();
  138. std::cerr << confusionMatrix << std::endl;
  139. std::cerr << "average recognition rate: " << confusionMatrix.trace()/confusionMatrix.rows() << std::endl;
  140. }
  141. return 0;
  142. }