GPHIKRawClassifierMex.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. #ifdef NICE_USELIB_MEX
  2. /**
  3. * @file GPHIKRawClassifierMex.cpp
  4. * @author Alexander Freytag
  5. * @date 21-09-2015 (dd-mm-yyyy)
  6. * @brief Matlab-Interface of our GPHIKRawClassifier, allowing for training and classification without more advanced methods.
  7. */
  8. // STL includes
  9. #include <math.h>
  10. #include <matrix.h>
  11. #include <mex.h>
  12. // NICE-core includes
  13. #include <core/basics/Config.h>
  14. #include <core/basics/Timer.h>
  15. #include <core/vector/MatrixT.h>
  16. #include <core/vector/VectorT.h>
  17. // gp-hik-core includes
  18. #include "gp-hik-core/GPHIKRawClassifier.h"
  19. // Interface for conversion between Matlab and C objects
  20. #include "gp-hik-core/matlab/classHandleMtoC.h"
  21. #include "gp-hik-core/matlab/ConverterMatlabToNICE.h"
  22. #include "gp-hik-core/matlab/ConverterNICEToMatlab.h"
  23. using namespace std; //C basics
  24. using namespace NICE; // nice-core
  25. NICE::Config parseParametersGPHIKRawClassifier(const mxArray *prhs[], int nrhs)
  26. {
  27. NICE::Config conf;
  28. // if first argument is the filename of an existing config file,
  29. // read the config accordingly
  30. int i_start ( 0 );
  31. std::string variable = MatlabConversion::convertMatlabToString(prhs[i_start]);
  32. if(variable == "conf")
  33. {
  34. conf = NICE::Config ( MatlabConversion::convertMatlabToString( prhs[i_start+1] ) );
  35. i_start = i_start+2;
  36. }
  37. // now run over all given parameter specifications
  38. // and add them to the config
  39. for( int i=i_start; i < nrhs; i+=2 )
  40. {
  41. std::string variable = MatlabConversion::convertMatlabToString(prhs[i]);
  42. /////////////////////////////////////////
  43. // READ STANDARD BOOLEAN VARIABLES
  44. /////////////////////////////////////////
  45. if( (variable == "verbose") ||
  46. (variable == "debug") ||
  47. (variable == "use_quantization") ||
  48. (variable == "ils_verbose")
  49. )
  50. {
  51. if ( mxIsChar( prhs[i+1] ) )
  52. {
  53. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  54. if ( (value != "true") && (value != "false") )
  55. {
  56. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  57. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  58. }
  59. if( value == "true" )
  60. conf.sB("GPHIKRawClassifier", variable, true);
  61. else
  62. conf.sB("GPHIKRawClassifier", variable, false);
  63. }
  64. else if ( mxIsLogical( prhs[i+1] ) )
  65. {
  66. bool value = MatlabConversion::convertMatlabToBool( prhs[i+1] );
  67. conf.sB("GPHIKRawClassifier", variable, value);
  68. }
  69. else
  70. {
  71. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  72. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  73. }
  74. }
  75. /////////////////////////////////////////
  76. // READ STANDARD INT VARIABLES
  77. /////////////////////////////////////////
  78. /////////////////////////////////////////
  79. // READ STRICT POSITIVE INT VARIABLES
  80. /////////////////////////////////////////
  81. if ( (variable == "num_bins") ||
  82. ( variable == "ils_max_iterations" )||
  83. ( variable == "eig_value_max_iterations" )
  84. )
  85. {
  86. if ( mxIsDouble( prhs[i+1] ) )
  87. {
  88. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  89. if( value < 1 )
  90. {
  91. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  92. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  93. }
  94. conf.sI("GPHIKRawClassifier", variable, (int) value);
  95. }
  96. else if ( mxIsInt32( prhs[i+1] ) )
  97. {
  98. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  99. if( value < 1 )
  100. {
  101. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  102. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  103. }
  104. conf.sI("GPHIKRawClassifier", variable, value);
  105. }
  106. else
  107. {
  108. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  109. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  110. }
  111. }
  112. /////////////////////////////////////////
  113. // READ STANDARD DOUBLE VARIABLES
  114. /////////////////////////////////////////
  115. /////////////////////////////////////////
  116. // READ POSITIVE DOUBLE VARIABLES
  117. /////////////////////////////////////////
  118. if ( (variable == "f_tolerance") ||
  119. (variable == "ils_min_delta") ||
  120. (variable == "ils_min_residual") ||
  121. (variable == "noise")
  122. )
  123. {
  124. if ( mxIsDouble( prhs[i+1] ) )
  125. {
  126. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  127. if( value < 0.0 )
  128. {
  129. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  130. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  131. }
  132. conf.sD("GPHIKRawClassifier", variable, value);
  133. }
  134. else
  135. {
  136. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  137. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  138. }
  139. }
  140. /////////////////////////////////////////
  141. // READ REMAINING SPECIFIC VARIABLES
  142. /////////////////////////////////////////
  143. if(variable == "ils_method")
  144. {
  145. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  146. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  147. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  148. conf.sS("GPHIKRawClassifier", variable, value);
  149. }
  150. if(variable == "s_quantType")
  151. {
  152. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  153. if( value != "1d-aequi-0-1" && value != "1d-aequi-0-max" && value != "nd-aequi-0-max" )
  154. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'s_quantType\'. \'1d-aequi-0-1\' , \'1d-aequi-0-max\' or \'nd-aequi-0-max\' expected.");
  155. conf.sS("GPHIKRawClassifier", variable, value);
  156. }
  157. }
  158. return conf;
  159. }
  160. // MAIN MATLAB FUNCTION
  161. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  162. {
  163. // get the command string specifying what to do
  164. if (nrhs < 1)
  165. mexErrMsgTxt("No commands and options passed... Aborting!");
  166. if( !mxIsChar( prhs[0] ) )
  167. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  168. std::string cmd = MatlabConversion::convertMatlabToString( prhs[0] );
  169. // create object
  170. if ( !strcmp("new", cmd.c_str() ) )
  171. {
  172. // check output variable
  173. if (nlhs != 1)
  174. mexErrMsgTxt("New: One output expected.");
  175. // read config settings
  176. NICE::Config conf = parseParametersGPHIKRawClassifier(prhs+1,nrhs-1);
  177. // create class instance
  178. NICE::GPHIKRawClassifier * classifier = new NICE::GPHIKRawClassifier ( &conf, "GPHIKRawClassifier" /*sectionName in config*/ );
  179. // handle to the C++ instance
  180. plhs[0] = MatlabConversion::convertPtr2Mat<NICE::GPHIKRawClassifier>( classifier );
  181. return;
  182. }
  183. // in all other cases, there should be a second input,
  184. // which the be the class instance handle
  185. if (nrhs < 2)
  186. mexErrMsgTxt("Second input should be a class instance handle.");
  187. // delete object
  188. if ( !strcmp("delete", cmd.c_str() ) )
  189. {
  190. // Destroy the C++ object
  191. MatlabConversion::destroyObject<NICE::GPHIKRawClassifier>(prhs[1]);
  192. return;
  193. }
  194. // get the class instance pointer from the second input
  195. // every following function needs the classifier object
  196. NICE::GPHIKRawClassifier * classifier = MatlabConversion::convertMat2Ptr<NICE::GPHIKRawClassifier>(prhs[1]);
  197. ////////////////////////////////////////
  198. // Check which class method to call //
  199. ////////////////////////////////////////
  200. // standard train - assumes initialized object
  201. if (!strcmp("train", cmd.c_str() ))
  202. {
  203. // Check parameters
  204. if (nlhs < 0 || nrhs < 4)
  205. {
  206. mexErrMsgTxt("Train: Unexpected arguments.");
  207. }
  208. //------------- read the data --------------
  209. std::vector< const NICE::SparseVector *> examplesTrain;
  210. NICE::Vector yMultiTrain;
  211. if ( mxIsSparse( prhs[2] ) )
  212. {
  213. examplesTrain = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  214. }
  215. else
  216. {
  217. NICE::Matrix dataTrain;
  218. dataTrain = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  219. //----------------- convert data to sparse data structures ---------
  220. examplesTrain.resize( dataTrain.rows() );
  221. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  222. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  223. {
  224. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  225. }
  226. }
  227. yMultiTrain = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  228. //----------------- train our classifier -------------
  229. NICE::Timer t;
  230. t.start();
  231. classifier->train ( examplesTrain , yMultiTrain );
  232. t.stop();
  233. FILE * pFile;
  234. pFile = fopen ("/home/freytag/experiments/gphik/2015-08-28-gphik-classif/training_times_gphik_raw.txt","a");
  235. fprintf (pFile, "GPHIKRaw-Mex -- Time for training without data conversion: %f \n",t.getLast());
  236. fclose (pFile);
  237. //----------------- clean up -------------
  238. for(int i=0;i<examplesTrain.size();i++)
  239. delete examplesTrain[i];
  240. return;
  241. }
  242. // Classify
  243. if ( !strcmp("classify", cmd.c_str() ) )
  244. {
  245. // Check parameters
  246. if ( (nlhs < 0) || (nrhs < 2) )
  247. {
  248. mexErrMsgTxt("Test: Unexpected arguments.");
  249. }
  250. if ( mxIsSparse( prhs[2] ) )
  251. {
  252. if ( MatlabConversion::isSparseDataAMatrix( prhs[2] ) )
  253. {
  254. //----------------- conversion -------------
  255. std::vector< const NICE::SparseVector *> examplesTest;
  256. examplesTest = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  257. //----------------- classification -------------
  258. NICE::Vector results;
  259. NICE::Matrix scores;
  260. NICE::Timer t;
  261. t.start();
  262. classifier->classify ( examplesTest, results, scores );
  263. t.stop();
  264. FILE * pFile;
  265. pFile = fopen ("/home/freytag/experiments/gphik/2015-08-28-gphik-classif/test_times_gphik_raw_quant.txt","a");
  266. fprintf (pFile, "GPHIKRaw-Mex-Quant test time %f \n",t.getLast());
  267. fclose (pFile);
  268. //----------------- clean up -------------
  269. for ( std::vector< const NICE::SparseVector *>::iterator exIt = examplesTest.begin();
  270. exIt != examplesTest.end();
  271. exIt++
  272. )
  273. {
  274. delete *exIt;
  275. }
  276. //----------------- output -------------
  277. plhs[0] = MatlabConversion::convertVectorFromNice( results );
  278. if(nlhs >= 2)
  279. {
  280. plhs[1] = MatlabConversion::convertMatrixFromNice( scores );
  281. }
  282. return;
  283. }
  284. else
  285. {
  286. //----------------- conversion -------------
  287. NICE::SparseVector * example;
  288. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  289. //----------------- classification -------------
  290. uint result;
  291. NICE::SparseVector scores;
  292. NICE::Timer t;
  293. t.start();
  294. classifier->classify ( example, result, scores );
  295. t.stop();
  296. FILE * pFile;
  297. pFile = fopen ("/home/freytag/experiments/gphik/2015-08-28-gphik-classif/test_times_gphik_raw_quant.txt","a");
  298. fprintf (pFile, "GPHIKRaw-Mex-Quant test time %f \n",t.getLast());
  299. fclose (pFile);
  300. //----------------- clean up -------------
  301. delete example;
  302. //----------------- output -------------
  303. plhs[0] = mxCreateDoubleScalar( result );
  304. if(nlhs >= 2)
  305. {
  306. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  307. }
  308. return;
  309. }
  310. }
  311. else
  312. {
  313. //----------------- conversion -------------
  314. NICE::Vector * example;
  315. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  316. NICE::SparseVector * svec = new NICE::SparseVector( *example );
  317. delete example;
  318. //----------------- classification -------------
  319. uint result;
  320. NICE::SparseVector scores;
  321. classifier->classify ( svec, result, scores );
  322. //----------------- clean up -------------
  323. delete svec;
  324. //----------------- output -------------
  325. plhs[0] = mxCreateDoubleScalar( result );
  326. if(nlhs >= 2)
  327. {
  328. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  329. }
  330. return;
  331. }
  332. }
  333. // Test - evaluate classifier on whole test set
  334. if ( !strcmp("test", cmd.c_str() ) )
  335. {
  336. // Check parameters
  337. if (nlhs < 0 || nrhs < 4)
  338. mexErrMsgTxt("Test: Unexpected arguments.");
  339. //------------- read the data --------------
  340. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  341. std::vector< const NICE::SparseVector *> dataTest_sparse;
  342. NICE::Matrix dataTest_dense;
  343. if ( dataIsSparse )
  344. {
  345. dataTest_sparse = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  346. }
  347. else
  348. {
  349. dataTest_dense = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  350. }
  351. NICE::Vector yMultiTest;
  352. yMultiTest = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  353. // ------------------------------------------
  354. // ------------- PREPARATION --------------
  355. // ------------------------------------------
  356. // determine classes known during training and corresponding mapping
  357. // thereby allow for non-continous class labels
  358. std::set< uint > classesKnownTraining = classifier->getKnownClassNumbers();
  359. uint noClassesKnownTraining ( classesKnownTraining.size() );
  360. std::map< uint, uint > mapClNoToIdxTrain;
  361. std::set< uint >::const_iterator clTrIt = classesKnownTraining.begin();
  362. for ( uint i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  363. mapClNoToIdxTrain.insert ( std::pair< uint, uint > ( *clTrIt, i ) );
  364. // determine classes known during testing and corresponding mapping
  365. // thereby allow for non-continous class labels
  366. std::set< uint > classesKnownTest;
  367. classesKnownTest.clear();
  368. // determine which classes we have in our label vector
  369. // -> MATLAB: myClasses = unique(y);
  370. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  371. {
  372. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  373. {
  374. classesKnownTest.insert ( *it );
  375. }
  376. }
  377. int noClassesKnownTest ( classesKnownTest.size() );
  378. std::map< uint, uint> mapClNoToIdxTest;
  379. std::set< uint >::const_iterator clTestIt = classesKnownTest.begin();
  380. for ( uint i=0; i < noClassesKnownTest; i++, clTestIt++ )
  381. mapClNoToIdxTest.insert ( std::pair< uint, uint > ( *clTestIt, i ) );
  382. int i_numTestSamples;
  383. if ( dataIsSparse )
  384. i_numTestSamples = dataTest_sparse.size();
  385. else
  386. i_numTestSamples = (int) dataTest_dense.rows();
  387. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  388. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  389. // ------------------------------------------
  390. // ------------- CLASSIFICATION --------------
  391. // ------------------------------------------
  392. NICE::Timer t;
  393. double testTime (0.0);
  394. for (int i = 0; i < i_numTestSamples; i++)
  395. {
  396. //----------------- convert data to sparse data structures ---------
  397. uint result;
  398. NICE::SparseVector exampleScoresSparse;
  399. if ( dataIsSparse )
  400. {
  401. // and classify
  402. t.start();
  403. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  404. t.stop();
  405. testTime += t.getLast();
  406. }
  407. else
  408. {
  409. NICE::Vector example ( dataTest_dense.getRow(i) );
  410. NICE::SparseVector * svec = new NICE::SparseVector ( example );
  411. // and classify
  412. t.start();
  413. classifier->classify( svec, result, exampleScoresSparse );
  414. t.stop();
  415. testTime += t.getLast();
  416. delete svec;
  417. }
  418. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  419. int scoreCnt ( 0 );
  420. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  421. {
  422. scores(i,scoreCnt) = scoreIt->second;
  423. }
  424. }
  425. std::cerr << "Time for testing: " << testTime << std::endl;
  426. // clean up
  427. if ( dataIsSparse )
  428. {
  429. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  430. delete *it;
  431. }
  432. confusionMatrix.normalizeColumnsL1();
  433. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  434. plhs[0] = mxCreateDoubleScalar( recRate );
  435. if(nlhs >= 2)
  436. plhs[1] = MatlabConversion::convertMatrixFromNice(confusionMatrix);
  437. if(nlhs >= 3)
  438. plhs[2] = MatlabConversion::convertMatrixFromNice(scores);
  439. return;
  440. }
  441. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  442. // interface specific methods for incremental extensions
  443. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  444. // not supported here
  445. ///////////////////// INTERFACE PERSISTENT /////////////////////
  446. // interface specific methods for store and restore
  447. ///////////////////// INTERFACE PERSISTENT /////////////////////
  448. // not supported here
  449. // Got here, so command not recognized
  450. std::string errorMsg (cmd.c_str() );
  451. errorMsg += " -- command not recognized.";
  452. mexErrMsgTxt( errorMsg.c_str() );
  453. }
  454. #endif