GPHIKClassifierMex.cpp 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. #ifdef NICE_USELIB_MEX
  2. /**
  3. * @file GPHIKClassifierMex.cpp
  4. * @author Alexander Freytag
  5. * @date 07-01-2014 (dd-mm-yyyy)
  6. * @brief Matlab-Interface of our GPHIKClassifier, allowing for training, classification, optimization, variance prediction, incremental learning, and storing/re-storing.
  7. */
  8. // C math header and MATLAB MEX API includes
  9. #include <math.h>
  10. #include <matrix.h>
  11. #include <mex.h>
  12. // NICE-core includes
  13. #include <core/basics/Config.h>
  14. #include <core/basics/Timer.h>
  15. #include <core/vector/MatrixT.h>
  16. #include <core/vector/VectorT.h>
  17. // gp-hik-core includes
  18. #include "gp-hik-core/GPHIKClassifier.h"
  19. // Interface for conversion between Matlab and C objects
  20. #include "gp-hik-core/matlab/classHandleMtoC.h"
  21. #include "gp-hik-core/matlab/ConverterMatlabToNICE.h"
  22. #include "gp-hik-core/matlab/ConverterNICEToMatlab.h"
  23. using namespace std; //C basics
  24. using namespace NICE; // nice-core
  25. NICE::Config parseParametersGPHIKClassifier(const mxArray *prhs[], int nrhs)
  26. {
  27. NICE::Config conf;
  28. // if first argument is the filename of an existing config file,
  29. // read the config accordingly
  30. int i_start ( 0 );
  31. std::string variable = MatlabConversion::convertMatlabToString(prhs[i_start]);
  32. if(variable == "conf")
  33. {
  34. conf = NICE::Config ( MatlabConversion::convertMatlabToString( prhs[i_start+1] ) );
  35. i_start = i_start+2;
  36. }
  37. // now run over all given parameter specifications
  38. // and add them to the config
  39. for( int i=i_start; i < nrhs; i+=2 )
  40. {
  41. std::string variable = MatlabConversion::convertMatlabToString(prhs[i]);
  42. /////////////////////////////////////////
  43. // READ STANDARD BOOLEAN VARIABLES
  44. /////////////////////////////////////////
  45. if( (variable == "verboseTime") ||
  46. (variable == "verbose") ||
  47. (variable == "debug") ||
  48. (variable == "optimize_noise") ||
  49. (variable == "uncertaintyPredictionForClassification") ||
  50. (variable == "use_quantization") ||
  51. (variable == "ils_verbose")
  52. )
  53. {
  54. if ( mxIsChar( prhs[i+1] ) )
  55. {
  56. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  57. if ( (value != "true") && (value != "false") )
  58. {
  59. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  60. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  61. }
  62. if( value == "true" )
  63. conf.sB("GPHIKClassifier", variable, true);
  64. else
  65. conf.sB("GPHIKClassifier", variable, false);
  66. }
  67. else if ( mxIsLogical( prhs[i+1] ) )
  68. {
  69. bool value = MatlabConversion::convertMatlabToBool( prhs[i+1] );
  70. conf.sB("GPHIKClassifier", variable, value);
  71. }
  72. else
  73. {
  74. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  75. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  76. }
  77. }
  78. /////////////////////////////////////////
  79. // READ STANDARD INT VARIABLES
  80. /////////////////////////////////////////
  81. if ( (variable == "nrOfEigenvaluesToConsiderForVarApprox")
  82. )
  83. {
  84. if ( mxIsDouble( prhs[i+1] ) )
  85. {
  86. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  87. conf.sI("GPHIKClassifier", variable, (int) value);
  88. }
  89. else if ( mxIsInt32( prhs[i+1] ) )
  90. {
  91. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  92. conf.sI("GPHIKClassifier", variable, value);
  93. }
  94. else
  95. {
  96. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  97. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  98. }
  99. }
  100. /////////////////////////////////////////
  101. // READ STRICT POSITIVE INT VARIABLES
  102. /////////////////////////////////////////
  103. if ( (variable == "num_bins") ||
  104. (variable == "ils_max_iterations") ||
  105. (variable == "eig_value_max_iterations")
  106. )
  107. {
  108. if ( mxIsDouble( prhs[i+1] ) )
  109. {
  110. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  111. if( value < 1 )
  112. {
  113. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  114. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  115. }
  116. conf.sI("GPHIKClassifier", variable, (int) value);
  117. }
  118. else if ( mxIsInt32( prhs[i+1] ) )
  119. {
  120. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  121. if( value < 1 )
  122. {
  123. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  124. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  125. }
  126. conf.sI("GPHIKClassifier", variable, value);
  127. }
  128. else
  129. {
  130. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  131. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  132. }
  133. }
  134. /////////////////////////////////////////
  135. // READ STANDARD DOUBLE VARIABLES
  136. /////////////////////////////////////////
  137. if ( (variable == "parameter_upper_bound") ||
  138. (variable == "parameter_lower_bound")
  139. )
  140. {
  141. if ( mxIsDouble( prhs[i+1] ) )
  142. {
  143. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  144. conf.sD("GPHIKClassifier", variable, value);
  145. }
  146. else
  147. {
  148. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  149. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  150. }
  151. }
  152. /////////////////////////////////////////
  153. // READ POSITIVE DOUBLE VARIABLES
  154. /////////////////////////////////////////
  155. if ( (variable == "ils_min_delta") ||
  156. (variable == "ils_min_residual") ||
  157. (variable == "noise")
  158. )
  159. {
  160. if ( mxIsDouble( prhs[i+1] ) )
  161. {
  162. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  163. if( value < 0.0 )
  164. {
  165. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  166. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  167. }
  168. conf.sD("GPHIKClassifier", variable, value);
  169. }
  170. else
  171. {
  172. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  173. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  174. }
  175. }
  176. /////////////////////////////////////////
  177. // READ REMAINING SPECIFIC VARIABLES
  178. /////////////////////////////////////////
  179. if(variable == "ils_method")
  180. {
  181. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  182. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  183. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  184. conf.sS("GPHIKClassifier", variable, value);
  185. }
  186. if(variable == "optimization_method")
  187. {
  188. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  189. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  190. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  191. conf.sS("GPHIKClassifier", variable, value);
  192. }
  193. if(variable == "s_quantType")
  194. {
  195. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  196. if( value != "1d-aequi-0-1" && value != "1d-aequi-0-max" && value != "nd-aequi-0-max" )
  197. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'s_quantType\'. \'1d-aequi-0-1\' , \'1d-aequi-0-max\' or \'nd-aequi-0-max\' expected.");
  198. conf.sS("GPHIKClassifier", variable, value);
  199. }
  200. if(variable == "transform")
  201. {
  202. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  203. if( value != "identity" && value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  204. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'identity\', \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  205. conf.sS("GPHIKClassifier", variable, value);
  206. }
  207. if(variable == "varianceApproximation")
  208. {
  209. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  210. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  211. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  212. conf.sS("GPHIKClassifier", variable, value);
  213. }
  214. }
  215. return conf;
  216. }
  217. // MAIN MATLAB FUNCTION
  218. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  219. {
  220. // get the command string specifying what to do
  221. if (nrhs < 1)
  222. mexErrMsgTxt("No commands and options passed... Aborting!");
  223. if( !mxIsChar( prhs[0] ) )
  224. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  225. std::string cmd = MatlabConversion::convertMatlabToString( prhs[0] );
  226. // create object
  227. if ( !strcmp("new", cmd.c_str() ) )
  228. {
  229. // check output variable
  230. if (nlhs != 1)
  231. mexErrMsgTxt("New: One output expected.");
  232. // read config settings
  233. NICE::Config conf = parseParametersGPHIKClassifier(prhs+1,nrhs-1);
  234. // create class instance
  235. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf, "GPHIKClassifier" /*sectionName in config*/ );
  236. // handle to the C++ instance
  237. plhs[0] = MatlabConversion::convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  238. return;
  239. }
  240. // in all other cases, there should be a second input,
  241. // which the be the class instance handle
  242. if (nrhs < 2)
  243. mexErrMsgTxt("Second input should be a class instance handle.");
  244. // delete object
  245. if ( !strcmp("delete", cmd.c_str() ) )
  246. {
  247. // Destroy the C++ object
  248. MatlabConversion::destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  249. return;
  250. }
  251. // get the class instance pointer from the second input
  252. // every following function needs the classifier object
  253. NICE::GPHIKClassifier * classifier = MatlabConversion::convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  254. ////////////////////////////////////////
  255. // Check which class method to call //
  256. ////////////////////////////////////////
  257. // standard train - assumes initialized object
  258. if (!strcmp("train", cmd.c_str() ))
  259. {
  260. // Check parameters
  261. if (nlhs < 0 || nrhs < 4)
  262. {
  263. mexErrMsgTxt("Train: Unexpected arguments.");
  264. }
  265. //------------- read the data --------------
  266. std::vector< const NICE::SparseVector *> examplesTrain;
  267. NICE::Vector yMultiTrain;
  268. if ( mxIsSparse( prhs[2] ) )
  269. {
  270. examplesTrain = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  271. }
  272. else
  273. {
  274. NICE::Matrix dataTrain;
  275. dataTrain = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  276. //----------------- convert data to sparse data structures ---------
  277. examplesTrain.resize( dataTrain.rows() );
  278. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  279. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  280. {
  281. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  282. }
  283. }
  284. yMultiTrain = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  285. //----------------- train our classifier -------------
  286. classifier->train ( examplesTrain , yMultiTrain );
  287. //----------------- clean up -------------
  288. for(int i=0;i<examplesTrain.size();i++)
  289. delete examplesTrain[i];
  290. return;
  291. }
  292. // Classify
  293. if ( !strcmp("classify", cmd.c_str() ) )
  294. {
  295. // Check parameters
  296. if ( (nlhs < 0) || (nrhs < 2) )
  297. {
  298. mexErrMsgTxt("Test: Unexpected arguments.");
  299. }
  300. //------------- read the data --------------
  301. uint result;
  302. NICE::SparseVector scores;
  303. double uncertainty;
  304. if ( mxIsSparse( prhs[2] ) )
  305. {
  306. NICE::SparseVector * example;
  307. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  308. classifier->classify ( example, result, scores, uncertainty );
  309. //----------------- clean up -------------
  310. delete example;
  311. }
  312. else
  313. {
  314. NICE::Vector * example;
  315. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  316. classifier->classify ( example, result, scores, uncertainty );
  317. //----------------- clean up -------------
  318. delete example;
  319. }
  320. // output
  321. plhs[0] = mxCreateDoubleScalar( result );
  322. if(nlhs >= 2)
  323. {
  324. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  325. }
  326. if(nlhs >= 3)
  327. {
  328. plhs[2] = mxCreateDoubleScalar( uncertainty );
  329. }
  330. return;
  331. }
  332. // Uncertainty prediction
  333. if ( !strcmp("uncertainty", cmd.c_str() ) )
  334. {
  335. // Check parameters
  336. if ( (nlhs < 0) || (nrhs < 2) )
  337. {
  338. mexErrMsgTxt("Test: Unexpected arguments.");
  339. }
  340. double uncertainty;
  341. //------------- read the data --------------
  342. if ( mxIsSparse( prhs[2] ) )
  343. {
  344. NICE::SparseVector * example;
  345. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  346. classifier->predictUncertainty( example, uncertainty );
  347. //----------------- clean up -------------
  348. delete example;
  349. }
  350. else
  351. {
  352. NICE::Vector * example;
  353. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  354. classifier->predictUncertainty( example, uncertainty );
  355. //----------------- clean up -------------
  356. delete example;
  357. }
  358. // output
  359. plhs[0] = mxCreateDoubleScalar( uncertainty );
  360. return;
  361. }
  362. // Test - evaluate classifier on whole test set
  363. if ( !strcmp("test", cmd.c_str() ) )
  364. {
  365. // Check parameters
  366. if (nlhs < 0 || nrhs < 4)
  367. mexErrMsgTxt("Test: Unexpected arguments.");
  368. //------------- read the data --------------
  369. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  370. std::vector< const NICE::SparseVector *> dataTest_sparse;
  371. NICE::Matrix dataTest_dense;
  372. if ( dataIsSparse )
  373. {
  374. dataTest_sparse = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  375. }
  376. else
  377. {
  378. dataTest_dense = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  379. }
  380. NICE::Vector yMultiTest;
  381. yMultiTest = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  382. // ------------------------------------------
  383. // ------------- PREPARATION --------------
  384. // ------------------------------------------
  385. // determine classes known during training and corresponding mapping
  386. // thereby allow for non-continous class labels
  387. std::set< uint > classesKnownTraining = classifier->getKnownClassNumbers();
  388. uint noClassesKnownTraining ( classesKnownTraining.size() );
  389. std::map< uint, uint > mapClNoToIdxTrain;
  390. std::set< uint >::const_iterator clTrIt = classesKnownTraining.begin();
  391. for ( uint i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  392. mapClNoToIdxTrain.insert ( std::pair< uint, uint > ( *clTrIt, i ) );
  393. // determine classes known during testing and corresponding mapping
  394. // thereby allow for non-continous class labels
  395. std::set< uint > classesKnownTest;
  396. classesKnownTest.clear();
  397. // determine which classes we have in our label vector
  398. // -> MATLAB: myClasses = unique(y);
  399. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  400. {
  401. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  402. {
  403. classesKnownTest.insert ( *it );
  404. }
  405. }
  406. int noClassesKnownTest ( classesKnownTest.size() );
  407. std::map< uint, uint> mapClNoToIdxTest;
  408. std::set< uint >::const_iterator clTestIt = classesKnownTest.begin();
  409. for ( uint i=0; i < noClassesKnownTest; i++, clTestIt++ )
  410. mapClNoToIdxTest.insert ( std::pair< uint, uint > ( *clTestIt, i ) );
  411. int i_numTestSamples;
  412. if ( dataIsSparse )
  413. i_numTestSamples = dataTest_sparse.size();
  414. else
  415. i_numTestSamples = (int) dataTest_dense.rows();
  416. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  417. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  418. // ------------------------------------------
  419. // ------------- CLASSIFICATION --------------
  420. // ------------------------------------------
  421. NICE::Timer t;
  422. double testTime (0.0);
  423. for (int i = 0; i < i_numTestSamples; i++)
  424. {
  425. //----------------- convert data to sparse data structures ---------
  426. uint result;
  427. NICE::SparseVector exampleScoresSparse;
  428. if ( dataIsSparse )
  429. {
  430. // and classify
  431. t.start();
  432. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  433. t.stop();
  434. testTime += t.getLast();
  435. }
  436. else
  437. {
  438. NICE::Vector example ( dataTest_dense.getRow(i) );
  439. // and classify
  440. t.start();
  441. classifier->classify( &example, result, exampleScoresSparse );
  442. t.stop();
  443. testTime += t.getLast();
  444. }
  445. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  446. int scoreCnt ( 0 );
  447. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  448. {
  449. scores(i,scoreCnt) = scoreIt->second;
  450. }
  451. }
  452. std::cerr << "Time for testing: " << testTime << std::endl;
  453. // clean up
  454. if ( dataIsSparse )
  455. {
  456. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  457. delete *it;
  458. }
  459. confusionMatrix.normalizeColumnsL1();
  460. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  461. plhs[0] = mxCreateDoubleScalar( recRate );
  462. if(nlhs >= 2)
  463. plhs[1] = MatlabConversion::convertMatrixFromNice(confusionMatrix);
  464. if(nlhs >= 3)
  465. plhs[2] = MatlabConversion::convertMatrixFromNice(scores);
  466. return;
  467. }
  468. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  469. // interface specific methods for incremental extensions
  470. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  471. // addExample
  472. if ( !strcmp("addExample", cmd.c_str() ) )
  473. {
  474. // Check parameters
  475. if ( (nlhs < 0) || (nrhs < 4) )
  476. {
  477. mexErrMsgTxt("Test: Unexpected arguments.");
  478. }
  479. //------------- read the data --------------
  480. NICE::SparseVector * newExample;
  481. double newLabel;
  482. if ( mxIsSparse( prhs[2] ) )
  483. {
  484. newExample = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  485. }
  486. else
  487. {
  488. NICE::Vector * example;
  489. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  490. newExample = new NICE::SparseVector ( *example );
  491. //----------------- clean up -------------
  492. delete example;
  493. }
  494. newLabel = MatlabConversion::convertMatlabToDouble( prhs[3] );
  495. // setting performOptimizationAfterIncrement is optional
  496. if ( nrhs > 4 )
  497. {
  498. bool performOptimizationAfterIncrement;
  499. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  500. classifier->addExample ( newExample, newLabel, performOptimizationAfterIncrement );
  501. }
  502. else
  503. {
  504. classifier->addExample ( newExample, newLabel );
  505. }
  506. //----------------- clean up -------------
  507. delete newExample;
  508. return;
  509. }
  510. // addMultipleExamples
  511. if ( !strcmp("addMultipleExamples", cmd.c_str() ) )
  512. {
  513. // Check parameters
  514. if ( (nlhs < 0) || (nrhs < 4) )
  515. {
  516. mexErrMsgTxt("Test: Unexpected arguments.");
  517. }
  518. //------------- read the data --------------
  519. std::vector< const NICE::SparseVector *> newExamples;
  520. NICE::Vector newLabels;
  521. if ( mxIsSparse( prhs[2] ) )
  522. {
  523. newExamples = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  524. }
  525. else
  526. {
  527. NICE::Matrix newData;
  528. newData = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  529. //----------------- convert data to sparse data structures ---------
  530. newExamples.resize( newData.rows() );
  531. std::vector< const NICE::SparseVector *>::iterator exTrainIt = newExamples.begin();
  532. for (int i = 0; i < (int)newData.rows(); i++, exTrainIt++)
  533. {
  534. *exTrainIt = new NICE::SparseVector( newData.getRow(i) );
  535. }
  536. }
  537. newLabels = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  538. // setting performOptimizationAfterIncrement is optional
  539. if ( nrhs > 4 )
  540. {
  541. bool performOptimizationAfterIncrement;
  542. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  543. classifier->addMultipleExamples ( newExamples, newLabels, performOptimizationAfterIncrement );
  544. }
  545. else
  546. {
  547. classifier->addMultipleExamples ( newExamples, newLabels );
  548. }
  549. //----------------- clean up -------------
  550. for ( std::vector< const NICE::SparseVector *>::iterator exIt = newExamples.begin();
  551. exIt != newExamples.end(); exIt++
  552. )
  553. {
  554. delete *exIt;
  555. }
  556. return;
  557. }
  558. ///////////////////// INTERFACE PERSISTENT /////////////////////
  559. // interface specific methods for store and restore
  560. ///////////////////// INTERFACE PERSISTENT /////////////////////
  561. // store the classifier to an external file
  562. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  563. {
  564. // Check parameters
  565. if ( nrhs < 3 )
  566. mexErrMsgTxt("store: no destination given.");
  567. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  568. std::filebuf fb;
  569. fb.open ( s_destination.c_str(), ios::out );
  570. std::ostream os(&fb);
  571. //
  572. classifier->store( os );
  573. //
  574. fb.close();
  575. return;
  576. }
  577. // load classifier from external file
  578. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  579. {
  580. // Check parameters
  581. if ( nrhs < 3 )
  582. mexErrMsgTxt("restore: no destination given.");
  583. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  584. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  585. std::filebuf fbIn;
  586. fbIn.open ( s_destination.c_str(), ios::in );
  587. std::istream is (&fbIn);
  588. //
  589. classifier->restore( is );
  590. //
  591. fbIn.close();
  592. return;
  593. }
  594. // Got here, so command not recognized
  595. std::string errorMsg (cmd.c_str() );
  596. errorMsg += " -- command not recognized.";
  597. mexErrMsgTxt( errorMsg.c_str() );
  598. }
  599. #endif