// toyExample.cpp (3.2 KB)

/**
* @file toyExample.cpp
* @brief Demo program showing how to call some methods of the GPHIKClassifier class
* @author Alexander Freytag
* @date 19-10-2012
*/
#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <core/basics/Config.h>
#include <core/basics/Timer.h>
#include <core/vector/MatrixT.h>
#include <core/vector/VectorT.h>

#include "gp-hik-core/GPHIKClassifier.h"

using namespace std; //C basics
using namespace NICE; // nice-core
  16. int main (int argc, char* argv[])
  17. {
  18. Config conf ( argc, argv );
  19. std::string trainData = conf.gS( "main", "trainData", "progs/toyExampleSmallScaleTrain.data" );
  20. //------------- read the training data --------------
  21. NICE::Matrix dataTrain;
  22. NICE::Vector yBinTrain;
  23. NICE::Vector yMultiTrain;
  24. std::ifstream ifsTrain ( trainData.c_str() , ios::in );
  25. if (ifsTrain.good() )
  26. {
  27. ifsTrain >> dataTrain;
  28. ifsTrain >> yBinTrain;
  29. ifsTrain >> yMultiTrain;
  30. ifsTrain.close();
  31. }
  32. else
  33. {
  34. std::cerr << "Unable to read training data, aborting." << std::endl;
  35. return -1;
  36. }
  37. //----------------- convert data to sparse data structures ---------
  38. std::vector< NICE::SparseVector *> examplesTrain;
  39. examplesTrain.resize( dataTrain.rows() );
  40. std::vector< NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  41. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  42. {
  43. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  44. }
  45. //----------------- train our classifier -------------
  46. conf.sB("GPHIKClassifier", "verbose", false);
  47. GPHIKClassifier * classifier = new GPHIKClassifier ( &conf );
  48. classifier->train ( examplesTrain , yMultiTrain );
  49. // ------------------------------------------
  50. // ------------- CLASSIFICATION --------------
  51. // ------------------------------------------
  52. //------------- read the test data --------------
  53. NICE::Matrix dataTest;
  54. NICE::Vector yBinTest;
  55. NICE::Vector yMultiTest;
  56. std::string testData = conf.gS( "main", "testData", "progs/toyExampleTest.data" );
  57. std::ifstream ifsTest ( testData.c_str(), ios::in );
  58. if (ifsTest.good() )
  59. {
  60. ifsTest >> dataTest;
  61. ifsTest >> yBinTest;
  62. ifsTest >> yMultiTest;
  63. ifsTest.close();
  64. }
  65. else
  66. {
  67. std::cerr << "Unable to read test data, aborting." << std::endl;
  68. return -1;
  69. }
  70. //TODO adapt this to the actual number of classes
  71. NICE::Matrix confusionMatrix(3, 3, 0.0);
  72. NICE::Timer t;
  73. double testTime (0.0);
  74. for (int i = 0; i < (int)dataTest.rows(); i++)
  75. {
  76. //----------------- convert data to sparse data structures ---------
  77. NICE::SparseVector * example = new NICE::SparseVector( dataTest.getRow(i) );
  78. int result;
  79. NICE::SparseVector scores;
  80. // and classify
  81. t.start();
  82. classifier->classify( example, result, scores );
  83. t.stop();
  84. testTime += t.getLast();
  85. confusionMatrix(result, yMultiTest[i]) += 1.0;
  86. }
  87. std::cerr << "Time for testing: " << testTime << std::endl;
  88. confusionMatrix.normalizeColumnsL1();
  89. std::cerr << confusionMatrix << std::endl;
  90. std::cerr << "average recognition rate: " << confusionMatrix.trace()/confusionMatrix.rows() << std::endl;
  91. return 0;
  92. }