// GPHIK.cpp -- MATLAB MEX interface to NICE::GPHIKClassifier
  1. #include <math.h>
  2. #include <matrix.h>
  3. #include "mex.h"
  4. #include "classHandleMtoC.h"
  5. // NICE-core includes
  6. #include <core/basics/Config.h>
  7. #include <core/basics/Timer.h>
  8. #include <core/vector/MatrixT.h>
  9. #include <core/vector/VectorT.h>
  10. // gp-hik-core includes
  11. #include "gp-hik-core/GPHIKClassifier.h"
  12. using namespace std; //C basics
  13. using namespace NICE; // nice-core
  14. /* Pass analyze_sparse a pointer to a sparse mxArray. A sparse mxArray
  15. only stores its nonzero elements. The values of the nonzero elements
  16. are stored in the pr and pi arrays. The tricky part of analyzing
  17. sparse mxArray's is figuring out the indices where the nonzero
  18. elements are stored. (See the mxSetIr and mxSetJc reference pages
  19. for details. */
  20. std::vector< NICE::SparseVector * > convertSparseMatrixToNice(const mxArray *array_ptr)
  21. {
  22. double *pr;//, *pi;
  23. mwIndex *ir, *jc;
  24. mwSize col, total=0;
  25. mwIndex starting_row_index, stopping_row_index, current_row_index;
  26. mwSize i_numExamples, i_numDim;
  27. /* Get the starting positions of all four data arrays. */
  28. pr = mxGetPr(array_ptr);
  29. // pi = mxGetPi(array_ptr);
  30. ir = mxGetIr(array_ptr);
  31. jc = mxGetJc(array_ptr);
  32. // dimenions of the matrix -> feature dimension and number of examples
  33. i_numExamples = mxGetM(array_ptr);
  34. i_numDim = mxGetN(array_ptr);
  35. // initialize output variable
  36. std::vector< NICE::SparseVector * > sparseMatrix;
  37. sparseMatrix.resize ( i_numExamples );
  38. for ( std::vector< NICE::SparseVector * >::iterator matIt = sparseMatrix.begin();
  39. matIt != sparseMatrix.end(); matIt++)
  40. {
  41. *matIt = new NICE::SparseVector( i_numDim );
  42. }
  43. // now copy the data
  44. for (col=0; col < i_numDim; col++)
  45. {
  46. starting_row_index = jc[col];
  47. stopping_row_index = jc[col+1];
  48. // empty column?
  49. if (starting_row_index == stopping_row_index)
  50. continue;
  51. else
  52. {
  53. for ( current_row_index = starting_row_index;
  54. current_row_index < stopping_row_index;
  55. current_row_index++)
  56. {
  57. //note: no complex data supported her
  58. sparseMatrix[ ir[current_row_index] ]->insert( std::pair<int, double>( col, pr[total++] ) );
  59. } // for-loop
  60. }
  61. } // for-loop over columns
  62. return sparseMatrix;
  63. }
  64. // b_adaptIndexMtoC: if true, dim k will be inserted as k, not as k-1 (which would be the default for M->C)
  65. NICE::SparseVector convertSparseVectorToNice(const mxArray* array_ptr, const bool & b_adaptIndexMtoC = false )
  66. {
  67. double *pr, *pi;
  68. mwIndex *ir, *jc;
  69. mwSize col, total=0;
  70. mwIndex starting_row_index, stopping_row_index, current_row_index;
  71. mwSize dimy, dimx;
  72. /* Get the starting positions of all four data arrays. */
  73. pr = mxGetPr(array_ptr);
  74. pi = mxGetPi(array_ptr);
  75. ir = mxGetIr(array_ptr);
  76. jc = mxGetJc(array_ptr);
  77. // dimenions of the matrix -> feature dimension and number of examples
  78. dimy = mxGetM(array_ptr);
  79. dimx = mxGetN(array_ptr);
  80. double* ptr = mxGetPr(array_ptr);
  81. if(dimx != 1 && dimy != 1)
  82. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  83. NICE::SparseVector svec( std::max(dimx, dimy) );
  84. if ( dimx > 1)
  85. {
  86. for ( mwSize row=0; row < dimx; row++)
  87. {
  88. // empty column?
  89. if (jc[row] == jc[row+1])
  90. {
  91. continue;
  92. }
  93. else
  94. {
  95. //note: no complex data supported her
  96. double value ( pr[total++] );
  97. if ( b_adaptIndexMtoC )
  98. svec.insert( std::pair<int, double>( row+1, value ) );
  99. else
  100. svec.insert( std::pair<int, double>( row, value ) );
  101. }
  102. } // for loop over cols
  103. }
  104. else
  105. {
  106. mwSize numNonZero = jc[1]-jc[0];
  107. for ( mwSize colNonZero=0; colNonZero < numNonZero; colNonZero++)
  108. {
  109. //note: no complex data supported her
  110. double value ( pr[total++] );
  111. if ( b_adaptIndexMtoC )
  112. svec.insert( std::pair<int, double>( ir[colNonZero]+1, value ) );
  113. else
  114. svec.insert( std::pair<int, double>( ir[colNonZero], value ) );
  115. }
  116. }
  117. return svec;
  118. }
  119. // b_adaptIndexCtoM: if true, dim k will be inserted as k, not as k+1 (which would be the default for C->M)
  120. mxArray* convertSparseVectorFromNice( const NICE::SparseVector & scores, const bool & b_adaptIndexCtoM = false)
  121. {
  122. mxArray * matlabSparseVec = mxCreateSparse( scores.getDim() /*m*/, 1/*n*/, scores.size()/*nzmax*/, mxREAL);
  123. // To make the returned sparse mxArray useful, you must initialize the pr, ir, jc, and (if it exists) pi arrays.
  124. // mxCreateSparse allocates space for:
  125. //
  126. // A pr array of length nzmax.
  127. // A pi array of length nzmax, but only if ComplexFlag is mxCOMPLEX in C (1 in Fortran).
  128. // An ir array of length nzmax.
  129. // A jc array of length n+1.
  130. double* prPtr = mxGetPr(matlabSparseVec);
  131. mwIndex * ir = mxGetIr( matlabSparseVec );
  132. mwIndex * jc = mxGetJc( matlabSparseVec );
  133. jc[1] = scores.size(); jc[0] = 0;
  134. mwSize cnt = 0;
  135. for ( NICE::SparseVector::const_iterator myIt = scores.begin(); myIt != scores.end(); myIt++, cnt++ )
  136. {
  137. // set index
  138. if ( b_adaptIndexCtoM )
  139. ir[cnt] = myIt->first-1;
  140. else
  141. ir[cnt] = myIt->first;
  142. // set value
  143. prPtr[cnt] = myIt->second;
  144. }
  145. return matlabSparseVec;
  146. }
  147. mxArray* convertMatrixFromNice(NICE::Matrix & niceMatrix)
  148. {
  149. mxArray *matlabMatrix = mxCreateDoubleMatrix(niceMatrix.rows(),niceMatrix.cols(),mxREAL);
  150. double* matlabMatrixPtr = mxGetPr(matlabMatrix);
  151. for(int i=0; i<niceMatrix.rows(); i++)
  152. {
  153. for(int j=0; j<niceMatrix.cols(); j++)
  154. {
  155. matlabMatrixPtr[i + j*niceMatrix.rows()] = niceMatrix(i,j);
  156. }
  157. }
  158. return matlabMatrix;
  159. }
  160. NICE::Matrix convertMatrixToNice(const mxArray* matlabMatrix)
  161. {
  162. //todo: do not assume double
  163. const mwSize *dims;
  164. int dimx, dimy, numdims;
  165. //figure out dimensions
  166. dims = mxGetDimensions(matlabMatrix);
  167. numdims = mxGetNumberOfDimensions(matlabMatrix);
  168. dimy = (int)dims[0]; dimx = (int)dims[1];
  169. double* ptr = mxGetPr(matlabMatrix);
  170. NICE::Matrix niceMatrix(ptr, dimy, dimx, NICE::Matrix::external);
  171. return niceMatrix;
  172. }
  173. mxArray* convertVectorFromNice(NICE::Vector & niceVector)
  174. {
  175. //cout << "start convertVectorFromNice" << endl;
  176. mxArray *matlabVector = mxCreateDoubleMatrix(niceVector.size(), 1, mxREAL);
  177. double* matlabVectorPtr = mxGetPr(matlabVector);
  178. for(int i=0;i<niceVector.size(); i++)
  179. {
  180. matlabVectorPtr[i] = niceVector[i];
  181. }
  182. return matlabVector;
  183. }
  184. NICE::Vector convertVectorToNice(const mxArray* matlabMatrix)
  185. {
  186. //todo: do not assume double
  187. const mwSize *dims;
  188. int dimx, dimy, numdims;
  189. //figure out dimensions
  190. dims = mxGetDimensions(matlabMatrix);
  191. numdims = mxGetNumberOfDimensions(matlabMatrix);
  192. dimy = (int)dims[0]; dimx = (int)dims[1];
  193. double* ptr = mxGetPr(matlabMatrix);
  194. if(dimx != 1 && dimy != 1)
  195. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  196. int dim = max(dimx, dimy);
  197. NICE::Vector niceVector(dim, 0.0);
  198. for(int i=0;i<dim;i++)
  199. {
  200. niceVector(i) = ptr[i];
  201. }
  202. return niceVector;
  203. }
  204. std::string convertMatlabToString(const mxArray *matlabString)
  205. {
  206. if(!mxIsChar(matlabString))
  207. mexErrMsgIdAndTxt("mexnice:error","Expected string");
  208. char *cstring = mxArrayToString(matlabString);
  209. std::string s(cstring);
  210. mxFree(cstring);
  211. return s;
  212. }
  213. int convertMatlabToInt32(const mxArray *matlabInt32)
  214. {
  215. if(!mxIsInt32(matlabInt32))
  216. mexErrMsgIdAndTxt("mexnice:error","Expected int32");
  217. int* ptr = (int*)mxGetData(matlabInt32);
  218. return ptr[0];
  219. }
  220. double convertMatlabToDouble(const mxArray *matlabDouble)
  221. {
  222. if(!mxIsDouble(matlabDouble))
  223. mexErrMsgIdAndTxt("mexnice:error","Expected double");
  224. double* ptr = (double*)mxGetData(matlabDouble);
  225. return ptr[0];
  226. }
  227. NICE::Config parseParameters(const mxArray *prhs[], int nrhs)
  228. {
  229. NICE::Config conf;
  230. // if first argument is the filename of an existing config file,
  231. // read the config accordingly
  232. int i_start ( 0 );
  233. std::string variable = convertMatlabToString(prhs[i_start]);
  234. if(variable == "conf")
  235. {
  236. conf = NICE::Config ( convertMatlabToString( prhs[i_start+1] ) );
  237. i_start = i_start+2;
  238. }
  239. // now run over all given parameter specifications
  240. // and add them to the config
  241. for( int i=i_start; i < nrhs; i+=2 )
  242. {
  243. std::string variable = convertMatlabToString(prhs[i]);
  244. if(variable == "ils_verbose")
  245. {
  246. string value = convertMatlabToString(prhs[i+1]);
  247. if(value != "true" && value != "false")
  248. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_verbose\'. \'true\' or \'false\' expected.");
  249. if(value == "true")
  250. conf.sB("GPHIKClassifier", variable, true);
  251. else
  252. conf.sB("GPHIKClassifier", variable, false);
  253. }
  254. if(variable == "ils_max_iterations")
  255. {
  256. int value = convertMatlabToInt32(prhs[i+1]);
  257. if(value < 1)
  258. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_max_iterations\'.");
  259. conf.sI("GPHIKClassifier", variable, value);
  260. }
  261. if(variable == "ils_method")
  262. {
  263. string value = convertMatlabToString(prhs[i+1]);
  264. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  265. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  266. conf.sS("GPHIKClassifier", variable, value);
  267. }
  268. if(variable == "ils_min_delta")
  269. {
  270. double value = convertMatlabToDouble(prhs[i+1]);
  271. if(value < 0.0)
  272. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_min_delta\'.");
  273. conf.sD("GPHIKClassifier", variable, value);
  274. }
  275. if(variable == "ils_min_residual")
  276. {
  277. double value = convertMatlabToDouble(prhs[i+1]);
  278. if(value < 0.0)
  279. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_min_residual\'.");
  280. conf.sD("GPHIKClassifier", variable, value);
  281. }
  282. if(variable == "optimization_method")
  283. {
  284. string value = convertMatlabToString(prhs[i+1]);
  285. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  286. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  287. conf.sS("GPHIKClassifier", variable, value);
  288. }
  289. if(variable == "use_quantization")
  290. {
  291. string value = convertMatlabToString(prhs[i+1]);
  292. if(value != "true" && value != "false")
  293. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'use_quantization\'. \'true\' or \'false\' expected.");
  294. if(value == "true")
  295. conf.sB("GPHIKClassifier", variable, true);
  296. else
  297. conf.sB("GPHIKClassifier", variable, false);
  298. }
  299. if(variable == "num_bins")
  300. {
  301. int value = convertMatlabToInt32(prhs[i+1]);
  302. if(value < 1)
  303. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'num_bins\'.");
  304. conf.sI("GPHIKClassifier", variable, value);
  305. }
  306. if(variable == "transform")
  307. {
  308. string value = convertMatlabToString(prhs[i+1]);
  309. if(value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  310. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  311. conf.sS("GPHIKClassifier", variable, value);
  312. }
  313. if(variable == "verboseTime")
  314. {
  315. string value = convertMatlabToString(prhs[i+1]);
  316. if(value != "true" && value != "false")
  317. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'verboseTime\'. \'true\' or \'false\' expected.");
  318. if(value == "true")
  319. conf.sB("GPHIKClassifier", variable, true);
  320. else
  321. conf.sB("GPHIKClassifier", variable, false);
  322. }
  323. if(variable == "verbose")
  324. {
  325. string value = convertMatlabToString(prhs[i+1]);
  326. if(value != "true" && value != "false")
  327. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'verbose\'. \'true\' or \'false\' expected.");
  328. if(value == "true")
  329. conf.sB("GPHIKClassifier", variable, true);
  330. else
  331. conf.sB("GPHIKClassifier", variable, false);
  332. }
  333. if(variable == "noise")
  334. {
  335. double value = convertMatlabToDouble(prhs[i+1]);
  336. if(value < 0.0)
  337. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value larger than 0 for \'noise\'.");
  338. conf.sD("GPHIKClassifier", variable, value);
  339. }
  340. if(variable == "optimize_noise")
  341. {
  342. string value = convertMatlabToString(prhs[i+1]);
  343. if(value != "true" && value != "false")
  344. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimize_noise\'. \'true\' or \'false\' expected.");
  345. if(value == "true")
  346. conf.sB("GPHIKClassifier", variable, true);
  347. else
  348. conf.sB("GPHIKClassifier", variable, false);
  349. }
  350. if(variable == "varianceApproximation")
  351. {
  352. string value = convertMatlabToString(prhs[i+1]);
  353. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  354. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  355. conf.sS("GPHIKClassifier", variable, value);
  356. }
  357. if(variable == "nrOfEigenvaluesToConsiderForVarApprox")
  358. {
  359. double value = convertMatlabToDouble(prhs[i+1]);
  360. conf.sI("GPHIKClassifier", variable, (int) value);
  361. }
  362. }
  363. return conf;
  364. }
  365. // MAIN MATLAB FUNCTION
  366. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  367. {
  368. // get the command string specifying what to do
  369. if (nrhs < 1)
  370. mexErrMsgTxt("No commands and options passed... Aborting!");
  371. if( !mxIsChar( prhs[0] ) )
  372. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  373. std::string cmd = convertMatlabToString( prhs[0] );
  374. // create object
  375. if ( !strcmp("new", cmd.c_str() ) )
  376. {
  377. // check output variable
  378. if (nlhs != 1)
  379. mexErrMsgTxt("New: One output expected.");
  380. // read config settings
  381. NICE::Config conf = parseParameters(prhs+1,nrhs-1);
  382. // create class instance
  383. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf );
  384. // handle to the C++ instance
  385. plhs[0] = convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  386. return;
  387. }
  388. // in all other cases, there should be a second input,
  389. // which the be the class instance handle
  390. if (nrhs < 2)
  391. mexErrMsgTxt("Second input should be a class instance handle.");
  392. // delete object
  393. if ( !strcmp("delete", cmd.c_str() ) )
  394. {
  395. // Destroy the C++ object
  396. destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  397. return;
  398. }
  399. // get the class instance pointer from the second input
  400. // every following function needs the classifier object
  401. NICE::GPHIKClassifier * classifier = convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  402. ////////////////////////////////////////
  403. // Check which class method to call //
  404. ////////////////////////////////////////
  405. // standard train - assumes initialized object
  406. if (!strcmp("train", cmd.c_str() ))
  407. {
  408. // Check parameters
  409. if (nlhs < 0 || nrhs < 4)
  410. {
  411. mexErrMsgTxt("Train: Unexpected arguments.");
  412. }
  413. //------------- read the data --------------
  414. std::vector< NICE::SparseVector *> examplesTrain;
  415. NICE::Vector yMultiTrain;
  416. if ( mxIsSparse( prhs[2] ) )
  417. {
  418. examplesTrain = convertSparseMatrixToNice( prhs[2] );
  419. }
  420. else
  421. {
  422. NICE::Matrix dataTrain;
  423. dataTrain = convertMatrixToNice(prhs[2]);
  424. //----------------- convert data to sparse data structures ---------
  425. examplesTrain.resize( dataTrain.rows() );
  426. std::vector< NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  427. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  428. {
  429. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  430. }
  431. }
  432. yMultiTrain = convertVectorToNice(prhs[3]);
  433. // std::cerr << " DATA AFTER CONVERSION: \n" << std::endl;
  434. // int lineIdx(0);
  435. // for ( std::vector< NICE::SparseVector *>::const_iterator exTrainIt = examplesTrain.begin();
  436. // exTrainIt != examplesTrain.end(); exTrainIt++, lineIdx++)
  437. // {
  438. // std::cerr << "\n lineIdx: " << lineIdx << std::endl;
  439. // (*exTrainIt)->store( std::cerr );
  440. //
  441. // }
  442. // test assumption
  443. {
  444. if( yMultiTrain.Min() < 0)
  445. mexErrMsgIdAndTxt("mexnice:error","Class labels smaller 0 are not allowed");
  446. }
  447. //----------------- train our classifier -------------
  448. classifier->train ( examplesTrain , yMultiTrain );
  449. //----------------- clean up -------------
  450. for(int i=0;i<examplesTrain.size();i++)
  451. delete examplesTrain[i];
  452. return;
  453. }
  454. // Classify
  455. if ( !strcmp("classify", cmd.c_str() ) )
  456. {
  457. // Check parameters
  458. if ( (nlhs < 0) || (nrhs < 2) )
  459. {
  460. mexErrMsgTxt("Test: Unexpected arguments.");
  461. }
  462. //------------- read the data --------------
  463. int result;
  464. NICE::SparseVector scores;
  465. double uncertainty;
  466. if ( mxIsSparse( prhs[2] ) )
  467. {
  468. NICE::SparseVector * example;
  469. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  470. classifier->classify ( example, result, scores, uncertainty );
  471. //----------------- clean up -------------
  472. delete example;
  473. }
  474. else
  475. {
  476. NICE::Vector * example;
  477. example = new NICE::Vector ( convertVectorToNice(prhs[2]) );
  478. classifier->classify ( example, result, scores, uncertainty );
  479. //----------------- clean up -------------
  480. delete example;
  481. }
  482. // output
  483. plhs[0] = mxCreateDoubleScalar( result );
  484. if(nlhs >= 2)
  485. {
  486. plhs[1] = convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  487. }
  488. if(nlhs >= 3)
  489. {
  490. plhs[2] = mxCreateDoubleScalar( uncertainty );
  491. }
  492. return;
  493. }
  494. // Classify
  495. if ( !strcmp("uncertainty", cmd.c_str() ) )
  496. {
  497. // Check parameters
  498. if ( (nlhs < 0) || (nrhs < 2) )
  499. {
  500. mexErrMsgTxt("Test: Unexpected arguments.");
  501. }
  502. double uncertainty;
  503. //------------- read the data --------------
  504. if ( mxIsSparse( prhs[2] ) )
  505. {
  506. NICE::SparseVector * example;
  507. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  508. classifier->predictUncertainty( example, uncertainty );
  509. //----------------- clean up -------------
  510. delete example;
  511. }
  512. else
  513. {
  514. NICE::Vector * example;
  515. example = new NICE::Vector ( convertVectorToNice(prhs[2]) );
  516. classifier->predictUncertainty( example, uncertainty );
  517. //----------------- clean up -------------
  518. delete example;
  519. }
  520. // output
  521. plhs[0] = mxCreateDoubleScalar( uncertainty );
  522. return;
  523. }
  524. // Test
  525. if ( !strcmp("test", cmd.c_str() ) )
  526. {
  527. // Check parameters
  528. if (nlhs < 0 || nrhs < 4)
  529. mexErrMsgTxt("Test: Unexpected arguments.");
  530. //------------- read the data --------------
  531. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  532. std::vector< NICE::SparseVector *> dataTest_sparse;
  533. NICE::Matrix dataTest_dense;
  534. if ( dataIsSparse )
  535. {
  536. dataTest_sparse = convertSparseMatrixToNice( prhs[2] );
  537. }
  538. else
  539. {
  540. dataTest_dense = convertMatrixToNice(prhs[2]);
  541. }
  542. NICE::Vector yMultiTest;
  543. yMultiTest = convertVectorToNice(prhs[3]);
  544. // ------------------------------------------
  545. // ------------- PREPARATION --------------
  546. // ------------------------------------------
  547. // determine classes known during training and corresponding mapping
  548. // thereby allow for non-continous class labels
  549. std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();
  550. int noClassesKnownTraining ( classesKnownTraining.size() );
  551. std::map<int,int> mapClNoToIdxTrain;
  552. std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  553. for ( int i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  554. mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );
  555. // determine classes known during testing and corresponding mapping
  556. // thereby allow for non-continous class labels
  557. std::set<int> classesKnownTest;
  558. classesKnownTest.clear();
  559. // determine which classes we have in our label vector
  560. // -> MATLAB: myClasses = unique(y);
  561. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  562. {
  563. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  564. {
  565. classesKnownTest.insert ( *it );
  566. }
  567. }
  568. int noClassesKnownTest ( classesKnownTest.size() );
  569. std::map<int,int> mapClNoToIdxTest;
  570. std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  571. for ( int i=0; i < noClassesKnownTest; i++, clTestIt++ )
  572. mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );
  573. int i_numTestSamples;
  574. if ( dataIsSparse )
  575. i_numTestSamples = dataTest_sparse.size();
  576. else
  577. i_numTestSamples = (int) dataTest_dense.rows();
  578. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  579. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  580. // ------------------------------------------
  581. // ------------- CLASSIFICATION --------------
  582. // ------------------------------------------
  583. NICE::Timer t;
  584. double testTime (0.0);
  585. for (int i = 0; i < i_numTestSamples; i++)
  586. {
  587. //----------------- convert data to sparse data structures ---------
  588. int result;
  589. NICE::SparseVector exampleScoresSparse;
  590. if ( dataIsSparse )
  591. {
  592. // and classify
  593. t.start();
  594. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  595. t.stop();
  596. testTime += t.getLast();
  597. }
  598. else
  599. {
  600. NICE::Vector example ( dataTest_dense.getRow(i) );
  601. // and classify
  602. t.start();
  603. classifier->classify( &example, result, exampleScoresSparse );
  604. t.stop();
  605. testTime += t.getLast();
  606. }
  607. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  608. int scoreCnt ( 0 );
  609. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  610. scores(i,scoreCnt) = scoreIt->second;
  611. }
  612. std::cerr << "Time for testing: " << testTime << std::endl;
  613. // clean up
  614. if ( dataIsSparse )
  615. {
  616. for ( std::vector<NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  617. delete *it;
  618. }
  619. confusionMatrix.normalizeColumnsL1();
  620. //std::cerr << confusionMatrix << std::endl;
  621. double recRate = confusionMatrix.trace()/confusionMatrix.rows();
  622. //std::cerr << "average recognition rate: " << recRate << std::endl;
  623. plhs[0] = mxCreateDoubleScalar( recRate );
  624. if(nlhs >= 2)
  625. plhs[1] = convertMatrixFromNice(confusionMatrix);
  626. if(nlhs >= 3)
  627. plhs[2] = convertMatrixFromNice(scores);
  628. return;
  629. }
  630. // store the classifier
  631. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  632. {
  633. // Check parameters
  634. if ( nrhs < 3 )
  635. mexErrMsgTxt("store: no destination given.");
  636. std::string s_destination = convertMatlabToString( prhs[2] );
  637. std::filebuf fb;
  638. fb.open ( s_destination.c_str(), ios::out );
  639. std::ostream os(&fb);
  640. //
  641. classifier->store( os );
  642. //
  643. fb.close();
  644. return;
  645. }
  646. // load classifier from external file
  647. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  648. {
  649. // Check parameters
  650. if ( nrhs < 3 )
  651. mexErrMsgTxt("restore: no destination given.");
  652. std::string s_destination = convertMatlabToString( prhs[2] );
  653. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  654. std::filebuf fbIn;
  655. fbIn.open ( s_destination.c_str(), ios::in );
  656. std::istream is (&fbIn);
  657. //
  658. classifier->restore( is );
  659. //
  660. fbIn.close();
  661. return;
  662. }
  663. // Got here, so command not recognized
  664. std::string errorMsg (cmd.c_str() );
  665. errorMsg += " -- command not recognized.";
  666. mexErrMsgTxt( errorMsg.c_str() );
  667. }