% liblinear_train.m
  1. function svmmodel = liblinear_train ( labels, feat, settings )
  2. %
  3. % BRIEF
  4. % A simple wrapper to provide training of 1-vs-all-classification for LIBLINEAR. No
  5. % further settings are adjustable currently.
  6. %
  7. % INPUT
  8. % labels -- multi-class labels (#sample x 1)
  9. % feat -- features for training images (#samples x # dimensions)
  10. % settings -- struct for configuring the svm model training, e.g., via
  11. % 'b_verbose', 'f_svm_C', ...
  12. %
  13. % OUTPUT:
  14. % svmmodel -- cell ( #classes x 1 ), every model entry is obtained via
  15. % svmtrain of the corresponding 1-vs-all-problem
  16. %
  17. % date: 30-04-2014 ( dd-mm-yyyy )
  18. % last modified: 22-10-2015
  19. % author: Alexander Freytag, Christoph Käding
  20. if ( nargin < 3 )
  21. settings = [];
  22. end
  23. libsvm_options = '';
  24. % outputs for training
  25. if ( ~ getFieldWithDefault ( settings, 'b_verbose', false ) )
  26. libsvm_options = sprintf('%s -q', libsvm_options);
  27. end
  28. % cost parameter
  29. f_svm_C = getFieldWithDefault ( settings, 'f_svm_C', 1);
  30. libsvm_options = sprintf('%s -c %f', libsvm_options, f_svm_C);
  31. % do we want to use an offset for the hyperplane?
  32. if ( getFieldWithDefault ( settings, 'b_addOffset', false) )
  33. libsvm_options = sprintf('%s -B 1', libsvm_options);
  34. end
  35. % add multithreading
  36. % NOTE: - requires liblinear-multicore
  37. % - supports only -s 0, -s 2, or -s 11 (so far)
  38. i_numThreads = getFieldWithDefault ( settings, 'i_numThreads', 1);
  39. if i_numThreads > 1
  40. libsvm_options = sprintf('%s -n %d', libsvm_options, i_numThreads);
  41. end
  42. % which solver to use
  43. % copied from the liblinear manual:
  44. % for multi-class classification
  45. % 0 -- L2-regularized logistic regression (primal)
  46. % 1 -- L2-regularized L2-loss support vector classification (dual)
  47. % 2 -- L2-regularized L2-loss support vector classification (primal)
  48. % 3 -- L2-regularized L1-loss support vector classification (dual)
  49. % 4 -- support vector classification by Crammer and Singer
  50. % 5 -- L1-regularized L2-loss support vector classification
  51. % 6 -- L1-regularized logistic regression
  52. % 7 -- L2-regularized logistic regression (dual)
  53. i_svmSolver = getFieldWithDefault ( settings, 'i_svmSolver', 1);
  54. libsvm_options = sprintf('%s -s %d', libsvm_options, i_svmSolver);
  55. % increase penalty for positive samples according to invers ratio of
  56. % their number, i.e., if 1/3 is ratio of positive to negative samples, then
  57. % impact of positives is 3 the times of negatives
  58. %
  59. b_weightBalancing = getFieldWithDefault ( settings, 'b_weightBalancing', false);
  60. uniqueLabels = unique ( labels );
  61. i_numClasses = size ( uniqueLabels,1);
  62. %# train one-against-all models
  63. if ( ~b_weightBalancing)
  64. svmmodel = train( labels, feat, libsvm_options );
  65. else
  66. svmmodel = cell( i_numClasses,1);
  67. for k=1:i_numClasses
  68. yBin = 2*double( labels == uniqueLabels( k ) )-1;
  69. fraction = double(sum(yBin==1))/double(numel(yBin));
  70. libsvm_optionsLocal = sprintf('%s -w1 %f', libsvm_options, 1.0/fraction);
  71. svmmodel{ k } = train( yBin, feat, libsvm_optionsLocal );
  72. %store the unique class label for later evaluations.
  73. svmmodel{ k }.uniqueLabel = uniqueLabels( k );
  74. end
  75. end
  76. end