testImageNetBinary.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * @file testImageNetBinary.cpp
  3. * @brief perform ImageNet tests with binary tasks for OCC
  4. * @author Alexander Lütz
  5. * @date 23-05-2012 (dd-mm-yyyy)
  6. */
  7. #include "core/basics/Config.h"
  8. #ifdef NICE_USELIB_MATIO
  9. #include "vislearning/cbaselib/ClassificationResults.h"
  10. #include "vislearning/baselib/ProgressBar.h"
  11. #include "core/matlabAccess/MatFileIO.h"
  12. #include "vislearning/matlabAccessHighLevel/ImageNetData.h"
  13. #include "vislearning/classifier/kernelclassifier/KCGPOneClass.h"
  14. #include "vislearning/classifier/kernelclassifier/KCGPApproxOneClass.h"
  15. #include "vislearning/math/kernels/KernelData.h"
  16. #include "vislearning/math/kernels/Kernel.h"
  17. #include "vislearning/math/kernels/KernelRBF.h"
  18. #include "vislearning/math/kernels/KernelExp.h"
  19. // #include "fast-hik/tools.h"
  20. using namespace std;
  21. using namespace NICE;
  22. using namespace OBJREC;
  23. /**
  24. test the basic functionality of fast-hik hyperparameter optimization
  25. */
  26. int main (int argc, char **argv)
  27. {
  28. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  29. Config conf ( argc, argv );
  30. string resultsfile = conf.gS("main", "results", "results.txt" );
  31. int positiveClass = conf.gI("main", "positive_class");
  32. std::cerr << "Positive class is " << positiveClass << std::endl;
  33. sparse_t data;
  34. NICE::Vector y;
  35. std::cerr << "Reading ImageNet data ..." << std::endl;
  36. bool imageNetLocal = conf.gB("main", "imageNetLocal" , false);
  37. string imageNetPath;
  38. if (imageNetLocal)
  39. imageNetPath = "/users2/rodner/data/imagenet/devkit-1.0/";
  40. else
  41. imageNetPath = "/home/dbv/bilder/imagenet/devkit-1.0/";
  42. ImageNetData imageNet ( imageNetPath + "demo/" );
  43. // imageNet.getBatchData ( data, y, "train", "training" );
  44. LabeledSetVector train;
  45. imageNet.loadDataAsLabeledSetVector( train );
  46. //set up the kernel function
  47. double rbf_sigma = conf.gD("main", "rbf_sigma", -2.0 );
  48. KernelRBF kernelFunction ( rbf_sigma, 0.0 );
  49. //KernelExp kernelFunction ( rbf_sigma, 0.0, 0.0 );
  50. //set up our OC-classifier
  51. string classifierName = conf.gS("main", "classifier", "KCGPApproxOneClass");
  52. KernelClassifier *classifier;
  53. if(strcmp("KCGPApproxOneClass",classifierName.c_str())==0)
  54. {
  55. classifier = new KCGPApproxOneClass ( &conf, &kernelFunction );
  56. }
  57. else if (strcmp("KCGPOneClass",classifierName.c_str())==0) {
  58. classifier = new KCGPOneClass ( &conf, &kernelFunction );
  59. }
  60. else{ //default
  61. classifier = new KCGPApproxOneClass ( &conf, &kernelFunction );
  62. }
  63. //and perform the training
  64. classifier->teach( train );
  65. // uint n = y.size();
  66. //
  67. // set<int> positives;
  68. // set<int> negatives;
  69. //
  70. // map< int, set<int> > mysets;
  71. // for ( uint i = 0 ; i < n; i++ )
  72. // mysets[ y[i] ].insert ( i );
  73. //
  74. // if ( mysets[ positiveClass ].size() == 0 )
  75. // fthrow(Exception, "Class " << positiveClass << " is not available.");
  76. //
  77. // // add our positive examples
  78. // for ( set<int>::const_iterator i = mysets[positiveClass].begin(); i != mysets[positiveClass].end(); i++ )
  79. // positives.insert ( *i );
  80. //
  81. // int Nneg = conf.gI("main", "nneg", 1 );
  82. // for ( map<int, set<int> >::const_iterator k = mysets.begin(); k != mysets.end(); k++ )
  83. // {
  84. // int classno = k->first;
  85. // if ( classno == positiveClass )
  86. // continue;
  87. // const set<int> & s = k->second;
  88. // uint ind = 0;
  89. // for ( set<int>::const_iterator i = s.begin(); (i != s.end() && ind < Nneg); i++,ind++ )
  90. // negatives.insert ( *i );
  91. // }
  92. // std::cerr << "Number of positive examples: " << positives.size() << std::endl;
  93. // std::cerr << "Number of negative examples: " << negatives.size() << std::endl;
  94. // ------------------------------ TESTING ------------------------------
  95. std::cerr << "Reading ImageNet test data files (takes some seconds)..." << std::endl;
  96. imageNet.preloadData ( "val", "testing" );
  97. imageNet.loadExternalLabels ( imageNetPath + "data/ILSVRC2010_validation_ground_truth.txt" );
  98. ClassificationResults results;
  99. std::cerr << "Classification step ... with " << imageNet.getNumPreloadedExamples() << " examples" << std::endl;
  100. ProgressBar pb;
  101. for ( uint i = 0 ; i < (uint)imageNet.getNumPreloadedExamples(); i++ )
  102. {
  103. pb.update ( imageNet.getNumPreloadedExamples() );
  104. const SparseVector & svec = imageNet.getPreloadedExample ( i );
  105. NICE::Vector vec;
  106. svec.convertToVectorT( vec );
  107. // classification step
  108. ClassificationResult r = classifier->classify ( vec );
  109. // set ground truth label
  110. r.classno_groundtruth = (((int)imageNet.getPreloadedLabel ( i )) == positiveClass) ? 1 : 0;
  111. results.push_back ( r );
  112. }
  113. std::cerr << "Writing results to " << resultsfile << std::endl;
  114. results.writeWEKA ( resultsfile, 0 );
  115. double perfvalue = results.getBinaryClassPerformance( ClassificationResults::PERF_AUC );
  116. std::cerr << "Performance: " << perfvalue << std::endl;
  117. //don't waste memory
  118. delete classifier;
  119. return 0;
  120. }
  121. #else
  /**
   * @brief Placeholder entry point for builds without NICE_USELIB_MATIO.
   *
   * The ImageNet experiment needs MATIO, so this build has nothing to run and
   * simply exits with status 0 (the same value an empty main returns).
   */
  int main (int /* argc */, char ** /* argv */)
  {
    return 0;
  }
  125. #endif