GPHIKClassifierMex.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821
  1. #ifdef NICE_USELIB_MEX
  2. /**
  3. * @file GPHIKClassifierMex.cpp
  4. * @author Alexander Freytag
  5. * @date 07-01-2014 (dd-mm-yyyy)
  6. * @brief Matlab-Interface of our GPHIKClassifier, allowing for training, classification, optimization, variance prediction, incremental learning, and storing/re-storing.
  7. */
  8. // STL includes
  9. #include <math.h>
  10. #include <matrix.h>
  11. #include <mex.h>
  12. // NICE-core includes
  13. #include <core/basics/Config.h>
  14. #include <core/basics/Timer.h>
  15. #include <core/vector/MatrixT.h>
  16. #include <core/vector/VectorT.h>
  17. // gp-hik-core includes
  18. #include "gp-hik-core/GPHIKClassifier.h"
  19. // Interface for conversion between Matlab and C objects
  20. #include "gp-hik-core/matlab/classHandleMtoC.h"
  21. #include "gp-hik-core/matlab/ConverterMatlabToNICE.h"
  22. #include "gp-hik-core/matlab/ConverterNICEToMatlab.h"
  23. using namespace std; //C basics
  24. using namespace NICE; // nice-core
  25. NICE::Config parseParametersGPHIKClassifier(const mxArray *prhs[], int nrhs)
  26. {
  27. NICE::Config conf;
  28. // if first argument is the filename of an existing config file,
  29. // read the config accordingly
  30. int i_start ( 0 );
  31. std::string variable = MatlabConversion::convertMatlabToString(prhs[i_start]);
  32. if(variable == "conf")
  33. {
  34. conf = NICE::Config ( MatlabConversion::convertMatlabToString( prhs[i_start+1] ) );
  35. i_start = i_start+2;
  36. }
  37. // now run over all given parameter specifications
  38. // and add them to the config
  39. for( int i=i_start; i < nrhs; i+=2 )
  40. {
  41. std::string variable = MatlabConversion::convertMatlabToString(prhs[i]);
  42. /////////////////////////////////////////
  43. // READ STANDARD BOOLEAN VARIABLES
  44. /////////////////////////////////////////
  45. if( (variable == "verboseTime") ||
  46. (variable == "verbose") ||
  47. (variable == "debug") ||
  48. (variable == "optimize_noise") ||
  49. (variable == "uncertaintyPredictionForClassification") ||
  50. (variable == "use_quantization") ||
  51. (variable == "ils_verbose")
  52. )
  53. {
  54. if ( mxIsChar( prhs[i+1] ) )
  55. {
  56. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  57. if ( (value != "true") && (value != "false") )
  58. {
  59. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  60. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  61. }
  62. if( value == "true" )
  63. conf.sB("GPHIKClassifier", variable, true);
  64. else
  65. conf.sB("GPHIKClassifier", variable, false);
  66. }
  67. else if ( mxIsLogical( prhs[i+1] ) )
  68. {
  69. bool value = MatlabConversion::convertMatlabToBool( prhs[i+1] );
  70. conf.sB("GPHIKClassifier", variable, value);
  71. }
  72. else
  73. {
  74. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  75. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  76. }
  77. }
  78. /////////////////////////////////////////
  79. // READ STANDARD INT VARIABLES
  80. /////////////////////////////////////////
  81. if ( (variable == "nrOfEigenvaluesToConsiderForVarApprox")
  82. )
  83. {
  84. if ( mxIsDouble( prhs[i+1] ) )
  85. {
  86. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  87. conf.sI("GPHIKClassifier", variable, (int) value);
  88. }
  89. else if ( mxIsInt32( prhs[i+1] ) )
  90. {
  91. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  92. conf.sI("GPHIKClassifier", variable, value);
  93. }
  94. else
  95. {
  96. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  97. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  98. }
  99. }
  100. /////////////////////////////////////////
  101. // READ STRICT POSITIVE INT VARIABLES
  102. /////////////////////////////////////////
  103. if ( (variable == "num_bins") ||
  104. (variable == "ils_max_iterations") ||
  105. (variable == "eig_value_max_iterations") ||
  106. (variable == "downhillsimplex_max_iterations")
  107. )
  108. {
  109. if ( mxIsDouble( prhs[i+1] ) )
  110. {
  111. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  112. if( value < 1 )
  113. {
  114. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  115. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  116. }
  117. conf.sI("GPHIKClassifier", variable, (int) value);
  118. }
  119. else if ( mxIsInt32( prhs[i+1] ) )
  120. {
  121. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  122. if( value < 1 )
  123. {
  124. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  125. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  126. }
  127. conf.sI("GPHIKClassifier", variable, value);
  128. }
  129. else
  130. {
  131. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  132. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  133. }
  134. }
  135. /////////////////////////////////////////
  136. // READ STANDARD DOUBLE VARIABLES
  137. /////////////////////////////////////////
  138. if ( (variable == "parameter_upper_bound") ||
  139. (variable == "parameter_lower_bound")
  140. )
  141. {
  142. if ( mxIsDouble( prhs[i+1] ) )
  143. {
  144. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  145. conf.sD("GPHIKClassifier", variable, value);
  146. }
  147. else
  148. {
  149. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  150. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  151. }
  152. }
  153. /////////////////////////////////////////
  154. // READ POSITIVE DOUBLE VARIABLES
  155. /////////////////////////////////////////
  156. if ( (variable == "ils_min_delta") ||
  157. (variable == "ils_min_residual") ||
  158. (variable == "noise")
  159. )
  160. {
  161. if ( mxIsDouble( prhs[i+1] ) )
  162. {
  163. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  164. if( value < 0.0 )
  165. {
  166. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  167. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  168. }
  169. conf.sD("GPHIKClassifier", variable, value);
  170. }
  171. else
  172. {
  173. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  174. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  175. }
  176. }
  177. /////////////////////////////////////////
  178. // READ REMAINING SPECIFIC VARIABLES
  179. /////////////////////////////////////////
  180. if(variable == "ils_method")
  181. {
  182. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  183. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  184. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  185. conf.sS("GPHIKClassifier", variable, value);
  186. }
  187. if(variable == "optimization_method")
  188. {
  189. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  190. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  191. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  192. conf.sS("GPHIKClassifier", variable, value);
  193. }
  194. if(variable == "s_quantType")
  195. {
  196. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  197. if( value != "1d-aequi-0-1" && value != "1d-aequi-0-max" && value != "nd-aequi-0-max" )
  198. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'s_quantType\'. \'1d-aequi-0-1\' , \'1d-aequi-0-max\' or \'nd-aequi-0-max\' expected.");
  199. conf.sS("GPHIKClassifier", variable, value);
  200. }
  201. if(variable == "transform")
  202. {
  203. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  204. if( value != "identity" && value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  205. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'identity\', \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  206. conf.sS("GPHIKClassifier", variable, value);
  207. }
  208. if(variable == "varianceApproximation")
  209. {
  210. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  211. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  212. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  213. conf.sS("GPHIKClassifier", variable, value);
  214. }
  215. }
  216. return conf;
  217. }
  218. // MAIN MATLAB FUNCTION
  219. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  220. {
  221. // get the command string specifying what to do
  222. if (nrhs < 1)
  223. mexErrMsgTxt("No commands and options passed... Aborting!");
  224. if( !mxIsChar( prhs[0] ) )
  225. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  226. std::string cmd = MatlabConversion::convertMatlabToString( prhs[0] );
  227. // create object
  228. if ( !strcmp("new", cmd.c_str() ) )
  229. {
  230. // check output variable
  231. if (nlhs != 1)
  232. mexErrMsgTxt("New: One output expected.");
  233. // read config settings
  234. NICE::Config conf = parseParametersGPHIKClassifier(prhs+1,nrhs-1);
  235. // create class instance
  236. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf, "GPHIKClassifier" /*sectionName in config*/ );
  237. // handle to the C++ instance
  238. plhs[0] = MatlabConversion::convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  239. return;
  240. }
  241. // in all other cases, there should be a second input,
  242. // which the be the class instance handle
  243. if (nrhs < 2)
  244. mexErrMsgTxt("Second input should be a class instance handle.");
  245. // delete object
  246. if ( !strcmp("delete", cmd.c_str() ) )
  247. {
  248. // Destroy the C++ object
  249. MatlabConversion::destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  250. return;
  251. }
  252. // get the class instance pointer from the second input
  253. // every following function needs the classifier object
  254. NICE::GPHIKClassifier * classifier = MatlabConversion::convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  255. ////////////////////////////////////////
  256. // Check which class method to call //
  257. ////////////////////////////////////////
  258. // standard train - assumes initialized object
  259. if (!strcmp("train", cmd.c_str() ))
  260. {
  261. // Check parameters
  262. if (nlhs < 0 || nrhs < 4)
  263. {
  264. mexErrMsgTxt("Train: Unexpected arguments.");
  265. }
  266. //------------- read the data --------------
  267. std::vector< const NICE::SparseVector *> examplesTrain;
  268. NICE::Vector yMultiTrain;
  269. if ( mxIsSparse( prhs[2] ) )
  270. {
  271. examplesTrain = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  272. }
  273. else
  274. {
  275. NICE::Matrix dataTrain;
  276. dataTrain = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  277. //----------------- convert data to sparse data structures ---------
  278. examplesTrain.resize( dataTrain.rows() );
  279. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  280. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  281. {
  282. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  283. }
  284. }
  285. yMultiTrain = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  286. //----------------- train our classifier -------------
  287. classifier->train ( examplesTrain , yMultiTrain );
  288. //----------------- clean up -------------
  289. for(int i=0;i<examplesTrain.size();i++)
  290. delete examplesTrain[i];
  291. return;
  292. }
  293. // Classify
  294. if ( !strcmp("classify", cmd.c_str() ) )
  295. {
  296. // Check parameters
  297. if ( (nlhs < 0) || (nrhs < 2) )
  298. {
  299. mexErrMsgTxt("Test: Unexpected arguments.");
  300. }
  301. if ( mxIsSparse( prhs[2] ) )
  302. {
  303. if ( MatlabConversion::isSparseDataAMatrix( prhs[2] ) )
  304. {
  305. //----------------- conversion -------------
  306. std::vector< const NICE::SparseVector *> examplesTest;
  307. examplesTest = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  308. //----------------- classification -------------
  309. NICE::Vector results;
  310. NICE::Matrix scores;
  311. NICE::Vector uncertainties;
  312. classifier->classify ( examplesTest, results, scores, uncertainties );
  313. //----------------- clean up -------------
  314. for ( std::vector< const NICE::SparseVector *>::iterator exIt = examplesTest.begin();
  315. exIt != examplesTest.end();
  316. exIt++
  317. )
  318. {
  319. delete *exIt;
  320. }
  321. //----------------- output -------------
  322. plhs[0] = MatlabConversion::convertVectorFromNice( results );
  323. if(nlhs >= 2)
  324. {
  325. plhs[1] = MatlabConversion::convertMatrixFromNice( scores );
  326. }
  327. if(nlhs >= 3)
  328. {
  329. plhs[2] = MatlabConversion::convertVectorFromNice( uncertainties );
  330. }
  331. return;
  332. }
  333. else
  334. {
  335. //----------------- conversion -------------
  336. NICE::SparseVector * example;
  337. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  338. //----------------- classification -------------
  339. uint result;
  340. NICE::SparseVector scores;
  341. double uncertainty;
  342. classifier->classify ( example, result, scores, uncertainty );
  343. //----------------- clean up -------------
  344. delete example;
  345. //----------------- output -------------
  346. plhs[0] = mxCreateDoubleScalar( result );
  347. if(nlhs >= 2)
  348. {
  349. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  350. }
  351. if(nlhs >= 3)
  352. {
  353. plhs[2] = mxCreateDoubleScalar( uncertainty );
  354. }
  355. return;
  356. }
  357. }
  358. else
  359. {
  360. //----------------- conversion -------------
  361. NICE::Vector * example;
  362. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  363. //----------------- classification -------------
  364. uint result;
  365. NICE::SparseVector scores;
  366. double uncertainty;
  367. classifier->classify ( example, result, scores, uncertainty );
  368. //----------------- clean up -------------
  369. delete example;
  370. //----------------- output -------------
  371. plhs[0] = mxCreateDoubleScalar( result );
  372. if(nlhs >= 2)
  373. {
  374. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  375. }
  376. if(nlhs >= 3)
  377. {
  378. plhs[2] = mxCreateDoubleScalar( uncertainty );
  379. }
  380. return;
  381. }
  382. }
  383. // Uncertainty prediction
  384. if ( !strcmp("uncertainty", cmd.c_str() ) )
  385. {
  386. // Check parameters
  387. if ( (nlhs < 0) || (nrhs < 2) )
  388. {
  389. mexErrMsgTxt("Test: Unexpected arguments.");
  390. }
  391. double uncertainty;
  392. //------------- read the data --------------
  393. if ( mxIsSparse( prhs[2] ) )
  394. {
  395. NICE::SparseVector * example;
  396. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  397. classifier->predictUncertainty( example, uncertainty );
  398. //----------------- clean up -------------
  399. delete example;
  400. }
  401. else
  402. {
  403. NICE::Vector * example;
  404. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  405. classifier->predictUncertainty( example, uncertainty );
  406. //----------------- clean up -------------
  407. delete example;
  408. }
  409. // output
  410. plhs[0] = mxCreateDoubleScalar( uncertainty );
  411. return;
  412. }
  413. // Test - evaluate classifier on whole test set
  414. if ( !strcmp("test", cmd.c_str() ) )
  415. {
  416. // Check parameters
  417. if (nlhs < 0 || nrhs < 4)
  418. mexErrMsgTxt("Test: Unexpected arguments.");
  419. //------------- read the data --------------
  420. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  421. std::vector< const NICE::SparseVector *> dataTest_sparse;
  422. NICE::Matrix dataTest_dense;
  423. if ( dataIsSparse )
  424. {
  425. dataTest_sparse = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  426. }
  427. else
  428. {
  429. dataTest_dense = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  430. }
  431. NICE::Vector yMultiTest;
  432. yMultiTest = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  433. // ------------------------------------------
  434. // ------------- PREPARATION --------------
  435. // ------------------------------------------
  436. // determine classes known during training and corresponding mapping
  437. // thereby allow for non-continous class labels
  438. std::set< uint > classesKnownTraining = classifier->getKnownClassNumbers();
  439. uint noClassesKnownTraining ( classesKnownTraining.size() );
  440. std::map< uint, uint > mapClNoToIdxTrain;
  441. std::set< uint >::const_iterator clTrIt = classesKnownTraining.begin();
  442. for ( uint i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  443. mapClNoToIdxTrain.insert ( std::pair< uint, uint > ( *clTrIt, i ) );
  444. // determine classes known during testing and corresponding mapping
  445. // thereby allow for non-continous class labels
  446. std::set< uint > classesKnownTest;
  447. classesKnownTest.clear();
  448. // determine which classes we have in our label vector
  449. // -> MATLAB: myClasses = unique(y);
  450. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  451. {
  452. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  453. {
  454. classesKnownTest.insert ( *it );
  455. }
  456. }
  457. int noClassesKnownTest ( classesKnownTest.size() );
  458. std::map< uint, uint> mapClNoToIdxTest;
  459. std::set< uint >::const_iterator clTestIt = classesKnownTest.begin();
  460. for ( uint i=0; i < noClassesKnownTest; i++, clTestIt++ )
  461. mapClNoToIdxTest.insert ( std::pair< uint, uint > ( *clTestIt, i ) );
  462. int i_numTestSamples;
  463. if ( dataIsSparse )
  464. i_numTestSamples = dataTest_sparse.size();
  465. else
  466. i_numTestSamples = (int) dataTest_dense.rows();
  467. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  468. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  469. // ------------------------------------------
  470. // ------------- CLASSIFICATION --------------
  471. // ------------------------------------------
  472. NICE::Timer t;
  473. double testTime (0.0);
  474. for (int i = 0; i < i_numTestSamples; i++)
  475. {
  476. //----------------- convert data to sparse data structures ---------
  477. uint result;
  478. NICE::SparseVector exampleScoresSparse;
  479. if ( dataIsSparse )
  480. {
  481. // and classify
  482. t.start();
  483. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  484. t.stop();
  485. testTime += t.getLast();
  486. }
  487. else
  488. {
  489. NICE::Vector example ( dataTest_dense.getRow(i) );
  490. // and classify
  491. t.start();
  492. classifier->classify( &example, result, exampleScoresSparse );
  493. t.stop();
  494. testTime += t.getLast();
  495. }
  496. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  497. int scoreCnt ( 0 );
  498. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  499. {
  500. scores(i,scoreCnt) = scoreIt->second;
  501. }
  502. }
  503. // clean up
  504. if ( dataIsSparse )
  505. {
  506. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  507. delete *it;
  508. }
  509. confusionMatrix.normalizeColumnsL1();
  510. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  511. plhs[0] = mxCreateDoubleScalar( recRate );
  512. if(nlhs >= 2)
  513. plhs[1] = MatlabConversion::convertMatrixFromNice(confusionMatrix);
  514. if(nlhs >= 3)
  515. plhs[2] = MatlabConversion::convertMatrixFromNice(scores);
  516. return;
  517. }
  518. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  519. // interface specific methods for incremental extensions
  520. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  521. // addExample
  522. if ( !strcmp("addExample", cmd.c_str() ) )
  523. {
  524. // Check parameters
  525. if ( (nlhs < 0) || (nrhs < 4) )
  526. {
  527. mexErrMsgTxt("Test: Unexpected arguments.");
  528. }
  529. //------------- read the data --------------
  530. NICE::SparseVector * newExample;
  531. double newLabel;
  532. if ( mxIsSparse( prhs[2] ) )
  533. {
  534. newExample = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  535. }
  536. else
  537. {
  538. NICE::Vector * example;
  539. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  540. newExample = new NICE::SparseVector ( *example );
  541. //----------------- clean up -------------
  542. delete example;
  543. }
  544. newLabel = MatlabConversion::convertMatlabToDouble( prhs[3] );
  545. // setting performOptimizationAfterIncrement is optional
  546. if ( nrhs > 4 )
  547. {
  548. bool performOptimizationAfterIncrement;
  549. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  550. classifier->addExample ( newExample, newLabel, performOptimizationAfterIncrement );
  551. }
  552. else
  553. {
  554. classifier->addExample ( newExample, newLabel );
  555. }
  556. //----------------- clean up -------------
  557. delete newExample;
  558. return;
  559. }
  560. // addMultipleExamples
  561. if ( !strcmp("addMultipleExamples", cmd.c_str() ) )
  562. {
  563. // Check parameters
  564. if ( (nlhs < 0) || (nrhs < 4) )
  565. {
  566. mexErrMsgTxt("Test: Unexpected arguments.");
  567. }
  568. //------------- read the data --------------
  569. std::vector< const NICE::SparseVector *> newExamples;
  570. NICE::Vector newLabels;
  571. if ( mxIsSparse( prhs[2] ) )
  572. {
  573. newExamples = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  574. }
  575. else
  576. {
  577. NICE::Matrix newData;
  578. newData = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  579. //----------------- convert data to sparse data structures ---------
  580. newExamples.resize( newData.rows() );
  581. std::vector< const NICE::SparseVector *>::iterator exTrainIt = newExamples.begin();
  582. for (int i = 0; i < (int)newData.rows(); i++, exTrainIt++)
  583. {
  584. *exTrainIt = new NICE::SparseVector( newData.getRow(i) );
  585. }
  586. }
  587. newLabels = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  588. // setting performOptimizationAfterIncrement is optional
  589. if ( nrhs > 4 )
  590. {
  591. bool performOptimizationAfterIncrement;
  592. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  593. classifier->addMultipleExamples ( newExamples, newLabels, performOptimizationAfterIncrement );
  594. }
  595. else
  596. {
  597. classifier->addMultipleExamples ( newExamples, newLabels );
  598. }
  599. //----------------- clean up -------------
  600. for ( std::vector< const NICE::SparseVector *>::iterator exIt = newExamples.begin();
  601. exIt != newExamples.end(); exIt++
  602. )
  603. {
  604. delete *exIt;
  605. }
  606. return;
  607. }
  608. ///////////////////// INTERFACE PERSISTENT /////////////////////
  609. // interface specific methods for store and restore
  610. ///////////////////// INTERFACE PERSISTENT /////////////////////
  611. // store the classifier to an external file
  612. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  613. {
  614. // Check parameters
  615. if ( nrhs < 3 )
  616. mexErrMsgTxt("store: no destination given.");
  617. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  618. std::filebuf fb;
  619. fb.open ( s_destination.c_str(), ios::out );
  620. std::ostream os(&fb);
  621. //
  622. classifier->store( os );
  623. //
  624. fb.close();
  625. return;
  626. }
  627. // load classifier from external file
  628. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  629. {
  630. // Check parameters
  631. if ( nrhs < 3 )
  632. mexErrMsgTxt("restore: no destination given.");
  633. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  634. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  635. std::filebuf fbIn;
  636. fbIn.open ( s_destination.c_str(), ios::in );
  637. std::istream is (&fbIn);
  638. //
  639. classifier->restore( is );
  640. //
  641. fbIn.close();
  642. return;
  643. }
  644. // Got here, so command not recognized
  645. std::string errorMsg (cmd.c_str() );
  646. errorMsg += " -- command not recognized.";
  647. mexErrMsgTxt( errorMsg.c_str() );
  648. }
  649. #endif