GPHIKClassifierMex.cpp 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725
  1. /**
  2. * @file GPHIKClassifierMex.cpp
  3. * @author Alexander Freytag
  4. * @date 07-01-2014 (dd-mm-yyyy)
* @brief Matlab interface of our GPHIKClassifier, allowing for training, classification, optimization, variance prediction, incremental learning, and storing/restoring.
  6. */
// C / C++ standard library includes
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <math.h>
#include <set>
#include <string>
#include <vector>
// MATLAB includes
#include <matrix.h>
#include <mex.h>
// NICE-core includes
#include <core/basics/Config.h>
#include <core/basics/Timer.h>
#include <core/vector/MatrixT.h>
#include <core/vector/VectorT.h>
// gp-hik-core includes
#include "gp-hik-core/GPHIKClassifier.h"
// Interface for conversion between Matlab and C objects
#include "gp-hik-core/matlab/classHandleMtoC.h"
#include "gp-hik-core/matlab/ConverterMatlabToNICE.h"
#include "gp-hik-core/matlab/ConverterNICEToMatlab.h"
using namespace std; //C basics
using namespace NICE; // nice-core
  24. NICE::Config parseParametersGPHIKClassifier(const mxArray *prhs[], int nrhs)
  25. {
  26. NICE::Config conf;
  27. // if first argument is the filename of an existing config file,
  28. // read the config accordingly
  29. int i_start ( 0 );
  30. std::string variable = MatlabConversion::convertMatlabToString(prhs[i_start]);
  31. if(variable == "conf")
  32. {
  33. conf = NICE::Config ( MatlabConversion::convertMatlabToString( prhs[i_start+1] ) );
  34. i_start = i_start+2;
  35. }
  36. // now run over all given parameter specifications
  37. // and add them to the config
  38. for( int i=i_start; i < nrhs; i+=2 )
  39. {
  40. std::string variable = MatlabConversion::convertMatlabToString(prhs[i]);
  41. /////////////////////////////////////////
  42. // READ STANDARD BOOLEAN VARIABLES
  43. /////////////////////////////////////////
  44. if( (variable == "verboseTime") || (variable == "verbose") ||
  45. (variable == "optimize_noise") || (variable == "uncertaintyPredictionForClassification") ||
  46. (variable == "use_quantization") || (variable == "ils_verbose")
  47. )
  48. {
  49. if ( mxIsChar( prhs[i+1] ) )
  50. {
  51. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  52. if ( (value != "true") && (value != "false") )
  53. {
  54. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  55. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  56. }
  57. if( value == "true" )
  58. conf.sB("GPHIKClassifier", variable, true);
  59. else
  60. conf.sB("GPHIKClassifier", variable, false);
  61. }
  62. else if ( mxIsLogical( prhs[i+1] ) )
  63. {
  64. bool value = MatlabConversion::convertMatlabToBool( prhs[i+1] );
  65. conf.sB("GPHIKClassifier", variable, value);
  66. }
  67. else
  68. {
  69. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  70. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  71. }
  72. }
  73. /////////////////////////////////////////
  74. // READ STANDARD INT VARIABLES
  75. /////////////////////////////////////////
  76. if ( (variable == "nrOfEigenvaluesToConsiderForVarApprox")
  77. )
  78. {
  79. if ( mxIsDouble( prhs[i+1] ) )
  80. {
  81. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  82. conf.sI("GPHIKClassifier", variable, (int) value);
  83. }
  84. else if ( mxIsInt32( prhs[i+1] ) )
  85. {
  86. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  87. conf.sI("GPHIKClassifier", variable, value);
  88. }
  89. else
  90. {
  91. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  92. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  93. }
  94. }
  95. /////////////////////////////////////////
  96. // READ STRICT POSITIVE INT VARIABLES
  97. /////////////////////////////////////////
  98. if ( (variable == "num_bins") || (variable == "ils_max_iterations")
  99. )
  100. {
  101. if ( mxIsDouble( prhs[i+1] ) )
  102. {
  103. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  104. if( value < 1 )
  105. {
  106. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  107. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  108. }
  109. conf.sI("GPHIKClassifier", variable, (int) value);
  110. }
  111. else if ( mxIsInt32( prhs[i+1] ) )
  112. {
  113. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  114. if( value < 1 )
  115. {
  116. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  117. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  118. }
  119. conf.sI("GPHIKClassifier", variable, value);
  120. }
  121. else
  122. {
  123. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  124. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  125. }
  126. }
  127. /////////////////////////////////////////
  128. // READ POSITIVE DOUBLE VARIABLES
  129. /////////////////////////////////////////
  130. if ( (variable == "ils_min_delta") || (variable == "ils_min_residual") ||
  131. (variable == "noise")
  132. )
  133. {
  134. if ( mxIsDouble( prhs[i+1] ) )
  135. {
  136. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  137. if( value < 0.0 )
  138. {
  139. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  140. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  141. }
  142. conf.sD("GPHIKClassifier", variable, value);
  143. }
  144. else
  145. {
  146. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  147. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  148. }
  149. }
  150. /////////////////////////////////////////
  151. // READ REMAINING SPECIFIC VARIABLES
  152. /////////////////////////////////////////
  153. if(variable == "ils_method")
  154. {
  155. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  156. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  157. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  158. conf.sS("GPHIKClassifier", variable, value);
  159. }
  160. if(variable == "optimization_method")
  161. {
  162. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  163. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  164. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  165. conf.sS("GPHIKClassifier", variable, value);
  166. }
  167. if(variable == "transform")
  168. {
  169. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  170. if(value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  171. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  172. conf.sS("GPHIKClassifier", variable, value);
  173. }
  174. if(variable == "varianceApproximation")
  175. {
  176. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  177. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  178. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  179. conf.sS("GPHIKClassifier", variable, value);
  180. }
  181. }
  182. return conf;
  183. }
  184. // MAIN MATLAB FUNCTION
  185. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  186. {
  187. // get the command string specifying what to do
  188. if (nrhs < 1)
  189. mexErrMsgTxt("No commands and options passed... Aborting!");
  190. if( !mxIsChar( prhs[0] ) )
  191. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  192. std::string cmd = MatlabConversion::convertMatlabToString( prhs[0] );
  193. // create object
  194. if ( !strcmp("new", cmd.c_str() ) )
  195. {
  196. // check output variable
  197. if (nlhs != 1)
  198. mexErrMsgTxt("New: One output expected.");
  199. // read config settings
  200. NICE::Config conf = parseParametersGPHIKClassifier(prhs+1,nrhs-1);
  201. // create class instance
  202. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf, "GPHIKClassifier" /*sectionName in config*/ );
  203. // handle to the C++ instance
  204. plhs[0] = MatlabConversion::convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  205. return;
  206. }
  207. // in all other cases, there should be a second input,
  208. // which the be the class instance handle
  209. if (nrhs < 2)
  210. mexErrMsgTxt("Second input should be a class instance handle.");
  211. // delete object
  212. if ( !strcmp("delete", cmd.c_str() ) )
  213. {
  214. // Destroy the C++ object
  215. MatlabConversion::destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  216. return;
  217. }
  218. // get the class instance pointer from the second input
  219. // every following function needs the classifier object
  220. NICE::GPHIKClassifier * classifier = MatlabConversion::convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  221. ////////////////////////////////////////
  222. // Check which class method to call //
  223. ////////////////////////////////////////
  224. // standard train - assumes initialized object
  225. if (!strcmp("train", cmd.c_str() ))
  226. {
  227. // Check parameters
  228. if (nlhs < 0 || nrhs < 4)
  229. {
  230. mexErrMsgTxt("Train: Unexpected arguments.");
  231. }
  232. //------------- read the data --------------
  233. std::vector< const NICE::SparseVector *> examplesTrain;
  234. NICE::Vector yMultiTrain;
  235. if ( mxIsSparse( prhs[2] ) )
  236. {
  237. examplesTrain = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  238. }
  239. else
  240. {
  241. NICE::Matrix dataTrain;
  242. dataTrain = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  243. //----------------- convert data to sparse data structures ---------
  244. examplesTrain.resize( dataTrain.rows() );
  245. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  246. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  247. {
  248. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  249. }
  250. }
  251. yMultiTrain = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  252. //----------------- train our classifier -------------
  253. classifier->train ( examplesTrain , yMultiTrain );
  254. //----------------- clean up -------------
  255. for(int i=0;i<examplesTrain.size();i++)
  256. delete examplesTrain[i];
  257. return;
  258. }
  259. // Classify
  260. if ( !strcmp("classify", cmd.c_str() ) )
  261. {
  262. // Check parameters
  263. if ( (nlhs < 0) || (nrhs < 2) )
  264. {
  265. mexErrMsgTxt("Test: Unexpected arguments.");
  266. }
  267. //------------- read the data --------------
  268. int result;
  269. NICE::SparseVector scores;
  270. double uncertainty;
  271. if ( mxIsSparse( prhs[2] ) )
  272. {
  273. NICE::SparseVector * example;
  274. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  275. classifier->classify ( example, result, scores, uncertainty );
  276. //----------------- clean up -------------
  277. delete example;
  278. }
  279. else
  280. {
  281. NICE::Vector * example;
  282. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  283. classifier->classify ( example, result, scores, uncertainty );
  284. //----------------- clean up -------------
  285. delete example;
  286. }
  287. // output
  288. plhs[0] = mxCreateDoubleScalar( result );
  289. if(nlhs >= 2)
  290. {
  291. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  292. }
  293. if(nlhs >= 3)
  294. {
  295. plhs[2] = mxCreateDoubleScalar( uncertainty );
  296. }
  297. return;
  298. }
  299. // Uncertainty prediction
  300. if ( !strcmp("uncertainty", cmd.c_str() ) )
  301. {
  302. // Check parameters
  303. if ( (nlhs < 0) || (nrhs < 2) )
  304. {
  305. mexErrMsgTxt("Test: Unexpected arguments.");
  306. }
  307. double uncertainty;
  308. //------------- read the data --------------
  309. if ( mxIsSparse( prhs[2] ) )
  310. {
  311. NICE::SparseVector * example;
  312. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  313. classifier->predictUncertainty( example, uncertainty );
  314. //----------------- clean up -------------
  315. delete example;
  316. }
  317. else
  318. {
  319. NICE::Vector * example;
  320. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  321. classifier->predictUncertainty( example, uncertainty );
  322. //----------------- clean up -------------
  323. delete example;
  324. }
  325. // output
  326. plhs[0] = mxCreateDoubleScalar( uncertainty );
  327. return;
  328. }
  329. // Test - evaluate classifier on whole test set
  330. if ( !strcmp("test", cmd.c_str() ) )
  331. {
  332. // Check parameters
  333. if (nlhs < 0 || nrhs < 4)
  334. mexErrMsgTxt("Test: Unexpected arguments.");
  335. //------------- read the data --------------
  336. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  337. std::vector< const NICE::SparseVector *> dataTest_sparse;
  338. NICE::Matrix dataTest_dense;
  339. if ( dataIsSparse )
  340. {
  341. dataTest_sparse = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  342. }
  343. else
  344. {
  345. dataTest_dense = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  346. }
  347. NICE::Vector yMultiTest;
  348. yMultiTest = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  349. // ------------------------------------------
  350. // ------------- PREPARATION --------------
  351. // ------------------------------------------
  352. // determine classes known during training and corresponding mapping
  353. // thereby allow for non-continous class labels
  354. std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();
  355. int noClassesKnownTraining ( classesKnownTraining.size() );
  356. std::map<int,int> mapClNoToIdxTrain;
  357. std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  358. for ( int i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  359. mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );
  360. // determine classes known during testing and corresponding mapping
  361. // thereby allow for non-continous class labels
  362. std::set<int> classesKnownTest;
  363. classesKnownTest.clear();
  364. // determine which classes we have in our label vector
  365. // -> MATLAB: myClasses = unique(y);
  366. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  367. {
  368. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  369. {
  370. classesKnownTest.insert ( *it );
  371. }
  372. }
  373. int noClassesKnownTest ( classesKnownTest.size() );
  374. std::map<int,int> mapClNoToIdxTest;
  375. std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  376. for ( int i=0; i < noClassesKnownTest; i++, clTestIt++ )
  377. mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );
  378. int i_numTestSamples;
  379. if ( dataIsSparse )
  380. i_numTestSamples = dataTest_sparse.size();
  381. else
  382. i_numTestSamples = (int) dataTest_dense.rows();
  383. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  384. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  385. // ------------------------------------------
  386. // ------------- CLASSIFICATION --------------
  387. // ------------------------------------------
  388. NICE::Timer t;
  389. double testTime (0.0);
  390. for (int i = 0; i < i_numTestSamples; i++)
  391. {
  392. //----------------- convert data to sparse data structures ---------
  393. int result;
  394. NICE::SparseVector exampleScoresSparse;
  395. if ( dataIsSparse )
  396. {
  397. // and classify
  398. t.start();
  399. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  400. t.stop();
  401. testTime += t.getLast();
  402. }
  403. else
  404. {
  405. NICE::Vector example ( dataTest_dense.getRow(i) );
  406. // and classify
  407. t.start();
  408. classifier->classify( &example, result, exampleScoresSparse );
  409. t.stop();
  410. testTime += t.getLast();
  411. }
  412. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  413. int scoreCnt ( 0 );
  414. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  415. {
  416. scores(i,scoreCnt) = scoreIt->second;
  417. }
  418. }
  419. std::cerr << "Time for testing: " << testTime << std::endl;
  420. // clean up
  421. if ( dataIsSparse )
  422. {
  423. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  424. delete *it;
  425. }
  426. confusionMatrix.normalizeColumnsL1();
  427. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  428. plhs[0] = mxCreateDoubleScalar( recRate );
  429. if(nlhs >= 2)
  430. plhs[1] = MatlabConversion::convertMatrixFromNice(confusionMatrix);
  431. if(nlhs >= 3)
  432. plhs[2] = MatlabConversion::convertMatrixFromNice(scores);
  433. return;
  434. }
  435. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  436. // interface specific methods for incremental extensions
  437. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  438. // addExample
  439. if ( !strcmp("addExample", cmd.c_str() ) )
  440. {
  441. // Check parameters
  442. if ( (nlhs < 0) || (nrhs < 4) )
  443. {
  444. mexErrMsgTxt("Test: Unexpected arguments.");
  445. }
  446. //------------- read the data --------------
  447. NICE::SparseVector * newExample;
  448. double newLabel;
  449. if ( mxIsSparse( prhs[2] ) )
  450. {
  451. newExample = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  452. }
  453. else
  454. {
  455. NICE::Vector * example;
  456. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  457. newExample = new NICE::SparseVector ( *example );
  458. //----------------- clean up -------------
  459. delete example;
  460. }
  461. newLabel = MatlabConversion::convertMatlabToDouble( prhs[3] );
  462. // setting performOptimizationAfterIncrement is optional
  463. if ( nrhs > 4 )
  464. {
  465. bool performOptimizationAfterIncrement;
  466. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  467. classifier->addExample ( newExample, newLabel, performOptimizationAfterIncrement );
  468. }
  469. else
  470. {
  471. classifier->addExample ( newExample, newLabel );
  472. }
  473. //----------------- clean up -------------
  474. delete newExample;
  475. return;
  476. }
  477. // addMultipleExamples
  478. if ( !strcmp("addMultipleExamples", cmd.c_str() ) )
  479. {
  480. // Check parameters
  481. if ( (nlhs < 0) || (nrhs < 4) )
  482. {
  483. mexErrMsgTxt("Test: Unexpected arguments.");
  484. }
  485. //------------- read the data --------------
  486. std::vector< const NICE::SparseVector *> newExamples;
  487. NICE::Vector newLabels;
  488. if ( mxIsSparse( prhs[2] ) )
  489. {
  490. newExamples = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  491. }
  492. else
  493. {
  494. NICE::Matrix newData;
  495. newData = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  496. //----------------- convert data to sparse data structures ---------
  497. newExamples.resize( newData.rows() );
  498. std::vector< const NICE::SparseVector *>::iterator exTrainIt = newExamples.begin();
  499. for (int i = 0; i < (int)newData.rows(); i++, exTrainIt++)
  500. {
  501. *exTrainIt = new NICE::SparseVector( newData.getRow(i) );
  502. }
  503. }
  504. newLabels = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  505. // setting performOptimizationAfterIncrement is optional
  506. if ( nrhs > 4 )
  507. {
  508. bool performOptimizationAfterIncrement;
  509. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  510. classifier->addMultipleExamples ( newExamples, newLabels, performOptimizationAfterIncrement );
  511. }
  512. else
  513. {
  514. classifier->addMultipleExamples ( newExamples, newLabels );
  515. }
  516. //----------------- clean up -------------
  517. for ( std::vector< const NICE::SparseVector *>::iterator exIt = newExamples.begin();
  518. exIt != newExamples.end(); exIt++
  519. )
  520. {
  521. delete *exIt;
  522. }
  523. return;
  524. }
  525. ///////////////////// INTERFACE PERSISTENT /////////////////////
  526. // interface specific methods for store and restore
  527. ///////////////////// INTERFACE PERSISTENT /////////////////////
  528. // store the classifier to an external file
  529. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  530. {
  531. // Check parameters
  532. if ( nrhs < 3 )
  533. mexErrMsgTxt("store: no destination given.");
  534. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  535. std::filebuf fb;
  536. fb.open ( s_destination.c_str(), ios::out );
  537. std::ostream os(&fb);
  538. //
  539. classifier->store( os );
  540. //
  541. fb.close();
  542. return;
  543. }
  544. // load classifier from external file
  545. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  546. {
  547. // Check parameters
  548. if ( nrhs < 3 )
  549. mexErrMsgTxt("restore: no destination given.");
  550. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  551. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  552. std::filebuf fbIn;
  553. fbIn.open ( s_destination.c_str(), ios::in );
  554. std::istream is (&fbIn);
  555. //
  556. classifier->restore( is );
  557. //
  558. fbIn.close();
  559. return;
  560. }
  561. // Got here, so command not recognized
  562. std::string errorMsg (cmd.c_str() );
  563. errorMsg += " -- command not recognized.";
  564. mexErrMsgTxt( errorMsg.c_str() );
  565. }