// testFPClassifier.cpp (3.6 KB)
/**
 * @file testFPClassifier.cpp
 * @brief Test program for classifier evaluation: trains an FPCGPHIK classifier
 *        on a sparse training file and prints its accuracy on a sparse test file.
 * @author Erik Rodner
 * @date 2007-10-12
 */
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <core/basics/numerictools.h>
#include <core/basics/Config.h>
#include <core/basics/StringTools.h>
//----------
#include <vislearning/baselib/Preprocess.h>
#include <vislearning/cbaselib/MultiDataset.h>
#include <vislearning/cbaselib/ClassificationResults.h>
#include <vislearning/cbaselib/MutualInformation.h>
#include <vislearning/classifier/classifierbase/FeaturePoolClassifier.h>
#include <vislearning/classifier/fpclassifier/gphik/FPCGPHIK.h>
#include <vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
#include <vislearning/classifier/classifierinterfaces/VCFeaturePool.h>
#include <vislearning/math/cluster/GMM.h>
//----------
#undef DEBUG
using namespace OBJREC;
using namespace NICE;
using namespace std;
  27. int main ( int argc, char **argv )
  28. {
  29. fprintf ( stderr, "testClassifier: init\n" );
  30. std::set_terminate ( __gnu_cxx::__verbose_terminate_handler );
  31. Config conf ( argc, argv );
  32. FPCGPHIK *classifier = new FPCGPHIK ( &conf, "ClassiferGPHIK" );
  33. string trainfn = conf.gS ( "data", "trainfile" );
  34. string testfn = conf.gS ( "data", "testfile" );
  35. Examples trainex;
  36. ifstream intrain ( trainfn.c_str() );
  37. int parts = 0;
  38. while ( intrain.good() )
  39. {
  40. string line;
  41. getline ( intrain, line );
  42. vector<string> split;
  43. StringTools::split ( line, ' ', split );
  44. if ( split.size() == 0 )
  45. break;
  46. if ( parts > 0 )
  47. assert ( parts == ( int ) split.size() );
  48. parts = split.size();
  49. int classno = atoi ( split[0].c_str() );
  50. SparseVector *sv = new SparseVector();
  51. for ( uint i = 1; i < split.size();i++ )
  52. {
  53. vector<string> split2;
  54. StringTools::split ( split[i], ':', split2 );
  55. assert ( split2.size() == 2 );
  56. ( *sv ) [atoi ( split2[0].c_str() ) ] = atof ( split2[1].c_str() );
  57. }
  58. Example example;
  59. example.vec = NULL;
  60. example.svec = sv;
  61. trainex.push_back ( pair<int, Example> ( classno, example ) );
  62. }
  63. cout << "trainex.size(): " << trainex.size() << endl;
  64. Examples testex;
  65. ifstream intest ( testfn.c_str() );
  66. parts = 0;
  67. while ( intest.good() )
  68. {
  69. string line;
  70. getline ( intest, line );
  71. vector<string> split;
  72. StringTools::split ( line, ' ', split );
  73. if ( split.size() == 0 )
  74. break;
  75. if ( parts > 0 )
  76. assert ( parts == ( int ) split.size() );
  77. parts = split.size();
  78. int classno = atoi ( split[0].c_str() );
  79. SparseVector *sv = new SparseVector();
  80. for ( uint i = 1; i < split.size();i++ )
  81. {
  82. vector<string> split2;
  83. StringTools::split ( split[i], ':', split2 );
  84. assert ( split2.size() == 2 );
  85. double val = atof (split2[1].c_str());
  86. if(val != 0.0)
  87. ( *sv ) [atoi ( split2[0].c_str() ) ] = val;
  88. }
  89. Example example;
  90. example.vec = NULL;
  91. example.svec = sv;
  92. testex.push_back ( pair<int, Example> ( classno, example ) );
  93. }
  94. cout << "testex.size(): " << testex.size() << endl;
  95. FeaturePool fp;
  96. classifier->train ( fp, trainex );
  97. int counter = 0;
  98. for ( uint e = 0; e < testex.size(); e++ )
  99. {
  100. ClassificationResult r = classifier->classify ( testex[e].second );
  101. int bestclass = 0;
  102. double bestval = r.scores[0];
  103. for ( int j = 1 ; j < r.scores.size(); j++ )
  104. {
  105. if(r.scores[j] > bestval)
  106. {
  107. bestclass = j;
  108. bestval = r.scores[j];
  109. }
  110. }
  111. if(bestclass == testex[e].first)
  112. counter++;
  113. }
  114. cout << "avg: " << (double)counter/(double)testex.size() << endl;
  115. return 0;
  116. }