GPHIK.cpp 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848
  1. #include <math.h>
  2. #include <matrix.h>
  3. #include "mex.h"
  4. #include "classHandleMtoC.h"
  5. // NICE-core includes
  6. #include <core/basics/Config.h>
  7. #include <core/basics/Timer.h>
  8. #include <core/vector/MatrixT.h>
  9. #include <core/vector/VectorT.h>
  10. // gp-hik-core includes
  11. #include "gp-hik-core/GPHIKClassifier.h"
  12. using namespace std; //C basics
  13. using namespace NICE; // nice-core
  14. /* Pass analyze_sparse a pointer to a sparse mxArray. A sparse mxArray
  15. only stores its nonzero elements. The values of the nonzero elements
  16. are stored in the pr and pi arrays. The tricky part of analyzing
  17. sparse mxArray's is figuring out the indices where the nonzero
  18. elements are stored. (See the mxSetIr and mxSetJc reference pages
  19. for details. */
  20. std::vector< NICE::SparseVector * > convertSparseMatrixToNice(const mxArray *array_ptr)
  21. {
  22. double *pr;//, *pi;
  23. mwIndex *ir, *jc;
  24. mwSize col, total=0;
  25. mwIndex starting_row_index, stopping_row_index, current_row_index;
  26. mwSize i_numExamples, i_numDim;
  27. /* Get the starting positions of all four data arrays. */
  28. pr = mxGetPr(array_ptr);
  29. // pi = mxGetPi(array_ptr);
  30. ir = mxGetIr(array_ptr);
  31. jc = mxGetJc(array_ptr);
  32. // dimenions of the matrix -> feature dimension and number of examples
  33. i_numExamples = mxGetM(array_ptr);
  34. i_numDim = mxGetN(array_ptr);
  35. // initialize output variable
  36. std::vector< NICE::SparseVector * > sparseMatrix;
  37. sparseMatrix.resize ( i_numExamples );
  38. for ( std::vector< NICE::SparseVector * >::iterator matIt = sparseMatrix.begin();
  39. matIt != sparseMatrix.end(); matIt++)
  40. {
  41. *matIt = new NICE::SparseVector( i_numDim );
  42. }
  43. // now copy the data
  44. for (col=0; col < i_numDim; col++)
  45. {
  46. starting_row_index = jc[col];
  47. stopping_row_index = jc[col+1];
  48. // empty column?
  49. if (starting_row_index == stopping_row_index)
  50. continue;
  51. else
  52. {
  53. for ( current_row_index = starting_row_index;
  54. current_row_index < stopping_row_index;
  55. current_row_index++)
  56. {
  57. //note: no complex data supported her
  58. sparseMatrix[ ir[current_row_index] ]->insert( std::pair<int, double>( col, pr[total++] ) );
  59. } // for-loop
  60. }
  61. } // for-loop over columns
  62. return sparseMatrix;
  63. }
  64. // b_adaptIndexMtoC: if true, dim k will be inserted as k, not as k-1 (which would be the default for M->C)
  65. NICE::SparseVector convertSparseVectorToNice(const mxArray* array_ptr, const bool & b_adaptIndexMtoC = false )
  66. {
  67. double *pr, *pi;
  68. mwIndex *ir, *jc;
  69. mwSize col, total=0;
  70. mwIndex starting_row_index, stopping_row_index, current_row_index;
  71. mwSize dimy, dimx;
  72. /* Get the starting positions of all four data arrays. */
  73. pr = mxGetPr(array_ptr);
  74. pi = mxGetPi(array_ptr);
  75. ir = mxGetIr(array_ptr);
  76. jc = mxGetJc(array_ptr);
  77. // dimenions of the matrix -> feature dimension and number of examples
  78. dimy = mxGetM(array_ptr);
  79. dimx = mxGetN(array_ptr);
  80. double* ptr = mxGetPr(array_ptr);
  81. if(dimx != 1 && dimy != 1)
  82. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  83. NICE::SparseVector svec( std::max(dimx, dimy) );
  84. if ( dimx > 1)
  85. {
  86. for ( mwSize row=0; row < dimx; row++)
  87. {
  88. // empty column?
  89. if (jc[row] == jc[row+1])
  90. {
  91. continue;
  92. }
  93. else
  94. {
  95. //note: no complex data supported her
  96. double value ( pr[total++] );
  97. if ( b_adaptIndexMtoC )
  98. svec.insert( std::pair<int, double>( row+1, value ) );
  99. else
  100. svec.insert( std::pair<int, double>( row, value ) );
  101. }
  102. } // for loop over cols
  103. }
  104. else
  105. {
  106. mwSize numNonZero = jc[1]-jc[0];
  107. for ( mwSize colNonZero=0; colNonZero < numNonZero; colNonZero++)
  108. {
  109. //note: no complex data supported her
  110. double value ( pr[total++] );
  111. if ( b_adaptIndexMtoC )
  112. svec.insert( std::pair<int, double>( ir[colNonZero]+1, value ) );
  113. else
  114. svec.insert( std::pair<int, double>( ir[colNonZero], value ) );
  115. }
  116. }
  117. return svec;
  118. }
  119. // b_adaptIndexCtoM: if true, dim k will be inserted as k, not as k+1 (which would be the default for C->M)
  120. mxArray* convertSparseVectorFromNice( const NICE::SparseVector & scores, const bool & b_adaptIndexCtoM = false)
  121. {
  122. mxArray * matlabSparseVec = mxCreateSparse( scores.getDim() /*m*/, 1/*n*/, scores.size()/*nzmax*/, mxREAL);
  123. // To make the returned sparse mxArray useful, you must initialize the pr, ir, jc, and (if it exists) pi arrays.
  124. // mxCreateSparse allocates space for:
  125. //
  126. // A pr array of length nzmax.
  127. // A pi array of length nzmax, but only if ComplexFlag is mxCOMPLEX in C (1 in Fortran).
  128. // An ir array of length nzmax.
  129. // A jc array of length n+1.
  130. double* prPtr = mxGetPr(matlabSparseVec);
  131. mwIndex * ir = mxGetIr( matlabSparseVec );
  132. mwIndex * jc = mxGetJc( matlabSparseVec );
  133. jc[1] = scores.size(); jc[0] = 0;
  134. mwSize cnt = 0;
  135. for ( NICE::SparseVector::const_iterator myIt = scores.begin(); myIt != scores.end(); myIt++, cnt++ )
  136. {
  137. // set index
  138. if ( b_adaptIndexCtoM )
  139. ir[cnt] = myIt->first-1;
  140. else
  141. ir[cnt] = myIt->first;
  142. // set value
  143. prPtr[cnt] = myIt->second;
  144. }
  145. return matlabSparseVec;
  146. }
  147. mxArray* convertMatrixFromNice(NICE::Matrix & niceMatrix)
  148. {
  149. mxArray *matlabMatrix = mxCreateDoubleMatrix(niceMatrix.rows(),niceMatrix.cols(),mxREAL);
  150. double* matlabMatrixPtr = mxGetPr(matlabMatrix);
  151. for(int i=0; i<niceMatrix.rows(); i++)
  152. {
  153. for(int j=0; j<niceMatrix.cols(); j++)
  154. {
  155. matlabMatrixPtr[i + j*niceMatrix.rows()] = niceMatrix(i,j);
  156. }
  157. }
  158. return matlabMatrix;
  159. }
  160. NICE::Matrix convertMatrixToNice(const mxArray* matlabMatrix)
  161. {
  162. //todo: do not assume double
  163. const mwSize *dims;
  164. int dimx, dimy, numdims;
  165. //figure out dimensions
  166. dims = mxGetDimensions(matlabMatrix);
  167. numdims = mxGetNumberOfDimensions(matlabMatrix);
  168. dimy = (int)dims[0]; dimx = (int)dims[1];
  169. double* ptr = mxGetPr(matlabMatrix);
  170. NICE::Matrix niceMatrix(ptr, dimy, dimx, NICE::Matrix::external);
  171. return niceMatrix;
  172. }
  173. mxArray* convertVectorFromNice(NICE::Vector & niceVector)
  174. {
  175. //cout << "start convertVectorFromNice" << endl;
  176. mxArray *matlabVector = mxCreateDoubleMatrix(niceVector.size(), 1, mxREAL);
  177. double* matlabVectorPtr = mxGetPr(matlabVector);
  178. for(int i=0;i<niceVector.size(); i++)
  179. {
  180. matlabVectorPtr[i] = niceVector[i];
  181. }
  182. return matlabVector;
  183. }
  184. NICE::Vector convertVectorToNice(const mxArray* matlabMatrix)
  185. {
  186. //todo: do not assume double
  187. const mwSize *dims;
  188. int dimx, dimy, numdims;
  189. //figure out dimensions
  190. dims = mxGetDimensions(matlabMatrix);
  191. numdims = mxGetNumberOfDimensions(matlabMatrix);
  192. dimy = (int)dims[0]; dimx = (int)dims[1];
  193. double* ptr = mxGetPr(matlabMatrix);
  194. if(dimx != 1 && dimy != 1)
  195. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  196. int dim = max(dimx, dimy);
  197. NICE::Vector niceVector(dim, 0.0);
  198. for(int i=0;i<dim;i++)
  199. {
  200. niceVector(i) = ptr[i];
  201. }
  202. return niceVector;
  203. }
  204. std::string convertMatlabToString(const mxArray *matlabString)
  205. {
  206. if(!mxIsChar(matlabString))
  207. mexErrMsgIdAndTxt("mexnice:error","Expected string");
  208. char *cstring = mxArrayToString(matlabString);
  209. std::string s(cstring);
  210. mxFree(cstring);
  211. return s;
  212. }
  213. int convertMatlabToInt32(const mxArray *matlabInt32)
  214. {
  215. if(!mxIsInt32(matlabInt32))
  216. mexErrMsgIdAndTxt("mexnice:error","Expected int32");
  217. int* ptr = (int*)mxGetData(matlabInt32);
  218. return ptr[0];
  219. }
  220. double convertMatlabToDouble(const mxArray *matlabDouble)
  221. {
  222. if(!mxIsDouble(matlabDouble))
  223. mexErrMsgIdAndTxt("mexnice:error","Expected double");
  224. double* ptr = (double*)mxGetData(matlabDouble);
  225. return ptr[0];
  226. }
  227. Config parseParameters(const mxArray *prhs[], int nrhs)
  228. {
  229. Config conf;
  230. for(int i=0;i<nrhs;i+=2)
  231. {
  232. string variable = convertMatlabToString(prhs[i]);
  233. if(variable == "ils_verbose")
  234. {
  235. string value = convertMatlabToString(prhs[i+1]);
  236. if(value != "true" && value != "false")
  237. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_verbose\'. \'true\' or \'false\' expected.");
  238. if(value == "true")
  239. conf.sB("GPHIKClassifier", variable, true);
  240. else
  241. conf.sB("GPHIKClassifier", variable, false);
  242. }
  243. if(variable == "ils_max_iterations")
  244. {
  245. int value = convertMatlabToInt32(prhs[i+1]);
  246. if(value < 1)
  247. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_max_iterations\'.");
  248. conf.sI("GPHIKClassifier", variable, value);
  249. }
  250. if(variable == "ils_method")
  251. {
  252. string value = convertMatlabToString(prhs[i+1]);
  253. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  254. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  255. conf.sS("GPHIKClassifier", variable, value);
  256. }
  257. if(variable == "ils_min_delta")
  258. {
  259. double value = convertMatlabToDouble(prhs[i+1]);
  260. if(value < 0.0)
  261. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_min_delta\'.");
  262. conf.sD("GPHIKClassifier", variable, value);
  263. }
  264. if(variable == "ils_min_residual")
  265. {
  266. double value = convertMatlabToDouble(prhs[i+1]);
  267. if(value < 0.0)
  268. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_min_residual\'.");
  269. conf.sD("GPHIKClassifier", variable, value);
  270. }
  271. if(variable == "optimization_method")
  272. {
  273. string value = convertMatlabToString(prhs[i+1]);
  274. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  275. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  276. conf.sS("GPHIKClassifier", variable, value);
  277. }
  278. if(variable == "use_quantization")
  279. {
  280. string value = convertMatlabToString(prhs[i+1]);
  281. if(value != "true" && value != "false")
  282. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'use_quantization\'. \'true\' or \'false\' expected.");
  283. if(value == "true")
  284. conf.sB("GPHIKClassifier", variable, true);
  285. else
  286. conf.sB("GPHIKClassifier", variable, false);
  287. }
  288. if(variable == "num_bins")
  289. {
  290. int value = convertMatlabToInt32(prhs[i+1]);
  291. if(value < 1)
  292. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'num_bins\'.");
  293. conf.sI("GPHIKClassifier", variable, value);
  294. }
  295. if(variable == "transform")
  296. {
  297. string value = convertMatlabToString(prhs[i+1]);
  298. if(value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  299. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  300. conf.sS("GPHIKClassifier", variable, value);
  301. }
  302. if(variable == "verboseTime")
  303. {
  304. string value = convertMatlabToString(prhs[i+1]);
  305. if(value != "true" && value != "false")
  306. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'verboseTime\'. \'true\' or \'false\' expected.");
  307. if(value == "true")
  308. conf.sB("GPHIKClassifier", variable, true);
  309. else
  310. conf.sB("GPHIKClassifier", variable, false);
  311. }
  312. if(variable == "verbose")
  313. {
  314. string value = convertMatlabToString(prhs[i+1]);
  315. if(value != "true" && value != "false")
  316. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'verbose\'. \'true\' or \'false\' expected.");
  317. if(value == "true")
  318. conf.sB("GPHIKClassifier", variable, true);
  319. else
  320. conf.sB("GPHIKClassifier", variable, false);
  321. }
  322. if(variable == "noise")
  323. {
  324. double value = convertMatlabToDouble(prhs[i+1]);
  325. if(value < 0.0)
  326. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value larger than 0 for \'noise\'.");
  327. conf.sD("GPHIKClassifier", variable, value);
  328. }
  329. if(variable == "learn_balanced")
  330. {
  331. string value = convertMatlabToString(prhs[i+1]);
  332. if(value != "true" && value != "false")
  333. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'learn_balanced\'. \'true\' or \'false\' expected.");
  334. if(value == "true")
  335. conf.sB("GPHIKClassifier", variable, true);
  336. else
  337. conf.sB("GPHIKClassifier", variable, false);
  338. }
  339. if(variable == "optimize_noise")
  340. {
  341. string value = convertMatlabToString(prhs[i+1]);
  342. if(value != "true" && value != "false")
  343. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimize_noise\'. \'true\' or \'false\' expected.");
  344. if(value == "true")
  345. conf.sB("GPHIKClassifier", variable, true);
  346. else
  347. conf.sB("GPHIKClassifier", variable, false);
  348. }
  349. if(variable == "varianceApproximation")
  350. {
  351. string value = convertMatlabToString(prhs[i+1]);
  352. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  353. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  354. conf.sS("GPHIKClassifier", variable, value);
  355. }
  356. if(variable == "nrOfEigenvaluesToConsiderForVarApprox")
  357. {
  358. double value = convertMatlabToDouble(prhs[i+1]);
  359. conf.sI("GPHIKClassifier", variable, (int) value);
  360. }
  361. }
  362. return conf;
  363. }
  364. // MAIN MATLAB FUNCTION
  365. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  366. {
  367. // get the command string specifying what to do
  368. if (nrhs < 1)
  369. mexErrMsgTxt("No commands and options passed... Aborting!");
  370. if( !mxIsChar( prhs[0] ) )
  371. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  372. std::string cmd = convertMatlabToString( prhs[0] );
  373. // create object
  374. if ( !strcmp("new", cmd.c_str() ) )
  375. {
  376. // check output variable
  377. if (nlhs != 1)
  378. mexErrMsgTxt("New: One output expected.");
  379. // read config settings
  380. NICE::Config conf = parseParameters(prhs+1,nrhs-1);
  381. // create class instance
  382. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf );
  383. // handle to the C++ instance
  384. plhs[0] = convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  385. return;
  386. }
  387. // in all other cases, there should be a second input,
  388. // which the be the class instance handle
  389. if (nrhs < 2)
  390. mexErrMsgTxt("Second input should be a class instance handle.");
  391. // delete object
  392. if ( !strcmp("delete", cmd.c_str() ) )
  393. {
  394. // Destroy the C++ object
  395. destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  396. return;
  397. }
  398. // get the class instance pointer from the second input
  399. // every following function needs the classifier object
  400. NICE::GPHIKClassifier * classifier = convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  401. ////////////////////////////////////////
  402. // Check which class method to call //
  403. ////////////////////////////////////////
  404. // standard train - assumes initialized object
  405. if (!strcmp("train", cmd.c_str() ))
  406. {
  407. // Check parameters
  408. if (nlhs < 0 || nrhs < 4)
  409. {
  410. mexErrMsgTxt("Train: Unexpected arguments.");
  411. }
  412. //------------- read the data --------------
  413. std::vector< NICE::SparseVector *> examplesTrain;
  414. NICE::Vector yMultiTrain;
  415. if ( mxIsSparse( prhs[2] ) )
  416. {
  417. examplesTrain = convertSparseMatrixToNice( prhs[2] );
  418. }
  419. else
  420. {
  421. NICE::Matrix dataTrain;
  422. dataTrain = convertMatrixToNice(prhs[2]);
  423. //----------------- convert data to sparse data structures ---------
  424. examplesTrain.resize( dataTrain.rows() );
  425. std::vector< NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  426. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  427. {
  428. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  429. }
  430. }
  431. yMultiTrain = convertVectorToNice(prhs[3]);
  432. // std::cerr << " DATA AFTER CONVERSION: \n" << std::endl;
  433. // int lineIdx(0);
  434. // for ( std::vector< NICE::SparseVector *>::const_iterator exTrainIt = examplesTrain.begin();
  435. // exTrainIt != examplesTrain.end(); exTrainIt++, lineIdx++)
  436. // {
  437. // std::cerr << "\n lineIdx: " << lineIdx << std::endl;
  438. // (*exTrainIt)->store( std::cerr );
  439. //
  440. // }
  441. // test assumption
  442. {
  443. if( yMultiTrain.Min() < 0)
  444. mexErrMsgIdAndTxt("mexnice:error","Class labels smaller 0 are not allowed");
  445. }
  446. //----------------- train our classifier -------------
  447. classifier->train ( examplesTrain , yMultiTrain );
  448. //----------------- clean up -------------
  449. for(int i=0;i<examplesTrain.size();i++)
  450. delete examplesTrain[i];
  451. return;
  452. }
  453. // Classify
  454. if ( !strcmp("classify", cmd.c_str() ) )
  455. {
  456. // Check parameters
  457. if ( (nlhs < 0) || (nrhs < 2) )
  458. {
  459. mexErrMsgTxt("Test: Unexpected arguments.");
  460. }
  461. //------------- read the data --------------
  462. int result;
  463. NICE::SparseVector scores;
  464. double uncertainty;
  465. if ( mxIsSparse( prhs[2] ) )
  466. {
  467. NICE::SparseVector * example;
  468. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  469. classifier->classify ( example, result, scores, uncertainty );
  470. //----------------- clean up -------------
  471. delete example;
  472. }
  473. else
  474. {
  475. NICE::Vector * example;
  476. example = new NICE::Vector ( convertVectorToNice(prhs[2]) );
  477. classifier->classify ( example, result, scores, uncertainty );
  478. //----------------- clean up -------------
  479. delete example;
  480. }
  481. // output
  482. plhs[0] = mxCreateDoubleScalar( result );
  483. if(nlhs >= 2)
  484. {
  485. plhs[1] = convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  486. }
  487. if(nlhs >= 3)
  488. {
  489. plhs[2] = mxCreateDoubleScalar( uncertainty );
  490. }
  491. return;
  492. }
  493. // Classify
  494. if ( !strcmp("uncertainty", cmd.c_str() ) )
  495. {
  496. // Check parameters
  497. if ( (nlhs < 0) || (nrhs < 2) )
  498. {
  499. mexErrMsgTxt("Test: Unexpected arguments.");
  500. }
  501. double uncertainty;
  502. //------------- read the data --------------
  503. if ( mxIsSparse( prhs[2] ) )
  504. {
  505. NICE::SparseVector * example;
  506. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  507. classifier->predictUncertainty( example, uncertainty );
  508. //----------------- clean up -------------
  509. delete example;
  510. }
  511. else
  512. {
  513. NICE::Vector * example;
  514. example = new NICE::Vector ( convertVectorToNice(prhs[2]) );
  515. classifier->predictUncertainty( example, uncertainty );
  516. //----------------- clean up -------------
  517. delete example;
  518. }
  519. // output
  520. plhs[0] = mxCreateDoubleScalar( uncertainty );
  521. return;
  522. }
  523. // Test
  524. if ( !strcmp("test", cmd.c_str() ) )
  525. {
  526. // Check parameters
  527. if (nlhs < 0 || nrhs < 4)
  528. mexErrMsgTxt("Test: Unexpected arguments.");
  529. //------------- read the data --------------
  530. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  531. std::vector< NICE::SparseVector *> dataTest_sparse;
  532. NICE::Matrix dataTest_dense;
  533. if ( dataIsSparse )
  534. {
  535. dataTest_sparse = convertSparseMatrixToNice( prhs[2] );
  536. }
  537. else
  538. {
  539. dataTest_dense = convertMatrixToNice(prhs[2]);
  540. }
  541. NICE::Vector yMultiTest;
  542. yMultiTest = convertVectorToNice(prhs[3]);
  543. // ------------------------------------------
  544. // ------------- PREPARATION --------------
  545. // ------------------------------------------
  546. // determine classes known during training and corresponding mapping
  547. // thereby allow for non-continous class labels
  548. std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();
  549. int noClassesKnownTraining ( classesKnownTraining.size() );
  550. std::map<int,int> mapClNoToIdxTrain;
  551. std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  552. for ( int i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  553. mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );
  554. // determine classes known during testing and corresponding mapping
  555. // thereby allow for non-continous class labels
  556. std::set<int> classesKnownTest;
  557. classesKnownTest.clear();
  558. // determine which classes we have in our label vector
  559. // -> MATLAB: myClasses = unique(y);
  560. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  561. {
  562. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  563. {
  564. classesKnownTest.insert ( *it );
  565. }
  566. }
  567. int noClassesKnownTest ( classesKnownTest.size() );
  568. std::map<int,int> mapClNoToIdxTest;
  569. std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  570. for ( int i=0; i < noClassesKnownTest; i++, clTestIt++ )
  571. mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );
  572. int i_numTestSamples;
  573. if ( dataIsSparse )
  574. i_numTestSamples = dataTest_sparse.size();
  575. else
  576. i_numTestSamples = (int) dataTest_dense.rows();
  577. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  578. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  579. // ------------------------------------------
  580. // ------------- CLASSIFICATION --------------
  581. // ------------------------------------------
  582. NICE::Timer t;
  583. double testTime (0.0);
  584. for (int i = 0; i < i_numTestSamples; i++)
  585. {
  586. //----------------- convert data to sparse data structures ---------
  587. int result;
  588. NICE::SparseVector exampleScoresSparse;
  589. if ( dataIsSparse )
  590. {
  591. // and classify
  592. t.start();
  593. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  594. t.stop();
  595. testTime += t.getLast();
  596. }
  597. else
  598. {
  599. NICE::Vector example ( dataTest_dense.getRow(i) );
  600. // and classify
  601. t.start();
  602. classifier->classify( &example, result, exampleScoresSparse );
  603. t.stop();
  604. testTime += t.getLast();
  605. }
  606. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  607. int scoreCnt ( 0 );
  608. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  609. scores(i,scoreCnt) = scoreIt->second;
  610. }
  611. std::cerr << "Time for testing: " << testTime << std::endl;
  612. // clean up
  613. if ( dataIsSparse )
  614. {
  615. for ( std::vector<NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  616. delete *it;
  617. }
  618. confusionMatrix.normalizeColumnsL1();
  619. //std::cerr << confusionMatrix << std::endl;
  620. double recRate = confusionMatrix.trace()/confusionMatrix.rows();
  621. //std::cerr << "average recognition rate: " << recRate << std::endl;
  622. plhs[0] = mxCreateDoubleScalar( recRate );
  623. if(nlhs >= 2)
  624. plhs[1] = convertMatrixFromNice(confusionMatrix);
  625. if(nlhs >= 3)
  626. plhs[2] = convertMatrixFromNice(scores);
  627. return;
  628. }
  629. // store the classifier
  630. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  631. {
  632. // Check parameters
  633. if ( nrhs < 3 )
  634. mexErrMsgTxt("store: no destination given.");
  635. std::string s_destination = convertMatlabToString( prhs[2] );
  636. std::filebuf fb;
  637. fb.open ( s_destination.c_str(), ios::out );
  638. std::ostream os(&fb);
  639. //
  640. classifier->store( os );
  641. //
  642. fb.close();
  643. return;
  644. }
  645. // load classifier from external file
  646. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  647. {
  648. // Check parameters
  649. if ( nrhs < 3 )
  650. mexErrMsgTxt("restore: no destination given.");
  651. std::string s_destination = convertMatlabToString( prhs[2] );
  652. std::filebuf fbIn;
  653. fbIn.open ( s_destination.c_str(), ios::in );
  654. std::istream is (&fbIn);
  655. //
  656. classifier->restore( is );
  657. //
  658. fbIn.close();
  659. return;
  660. }
  661. // Got here, so command not recognized
  662. std::string errorMsg (cmd.c_str() );
  663. errorMsg += " -- command not recognized.";
  664. mexErrMsgTxt( errorMsg.c_str() );
  665. }