% evaluateIncrementalLearning.m
  1. function evaluateIncrementalLearning ( settings )
  2. % compare multi-class incremental learning with Gaussian processes using
  3. % either efficient model updates or batch training from scratch
  4. %
  5. % Copyright (c) by Alexander Freytag, 2013-11-13.
  6. if ( (nargin < 1) || isempty ( settings) )
  7. settings = [];
  8. end
  9. %setup variables for experiments
  10. if ( ~isfield( settings, 'i_numDim') || isempty(settings.i_numDim) )
  11. settings.i_numDim = 10;
  12. end
  13. if ( ~isfield( settings, 'd_mean') || isempty(settings.d_mean) )
  14. settings.d_mean = 3.0;
  15. end
  16. if ( ~isfield( settings, 'd_var') || isempty(settings.d_var) )
  17. settings.d_var = 1.0;
  18. end
  19. if ( ~isfield( settings, 'offsetSpecific') || isempty(settings.offsetSpecific) )
  20. settings.offsetSpecific = [ 0.0, 2.0, 4.0 ];
  21. end
  22. if ( ~isfield( settings, 'varSpecific') || isempty(settings.varSpecific) )
  23. settings.varSpecific = [2.0, 0.5, 1.0];
  24. end
  25. if ( ~isfield( settings, 'dimToChange') || isempty(settings.dimToChange) )
  26. settings.dimToChange = [ [1,3] , [2,4] , [5] ];
  27. end
  28. if ( ~isfield( settings, 'i_numClasses') || isempty(settings.i_numClasses) )
  29. settings.i_numClasses = 3;
  30. end
  31. if ( ~isfield( settings, 'i_numSamplesPerClassTrain') || isempty(settings.i_numSamplesPerClassTrain) )
  32. settings.i_numSamplesPerClassTrain = 10;
  33. end
  34. if ( ~isfield( settings, 'i_numSamplesPerClassUnlabeled') || isempty(settings.i_numSamplesPerClassUnlabeled) )
  35. settings.i_numSamplesPerClassUnlabeled = 50;
  36. end
  37. if ( ~isfield( settings, 'i_numSamplesPerClassTest') || isempty(settings.i_numSamplesPerClassTest) )
  38. settings.i_numSamplesPerClassTest = 10;
  39. end
  40. % how many experiments do you want to average?
  41. if ( ~isfield( settings, 'i_numRepetitions') || isempty(settings.i_numRepetitions) )
  42. i_numRepetitions = 10;
  43. end
  44. timesBatch = zeros ( i_numRepetitions, (settings.i_numClasses*settings.i_numSamplesPerClassUnlabeled+1) );
  45. timesEfficient = zeros ( i_numRepetitions, (settings.i_numClasses*settings.i_numSamplesPerClassUnlabeled+1) );
  46. arrBatch = zeros ( i_numRepetitions, (settings.i_numClasses*settings.i_numSamplesPerClassUnlabeled+1) );
  47. arrEfficient = zeros ( i_numRepetitions, (settings.i_numClasses*settings.i_numSamplesPerClassUnlabeled+1) );
  48. for i_idxRep = 1:i_numRepetitions
  49. %% (1) sample some features for initial train set, separate test set, and additional separate set to be added incrementally
  50. % sample training data
  51. settings.i_numSamplesPerClass = settings.i_numSamplesPerClassTrain;
  52. labeledData = sampleData( settings );
  53. % sample unlabeled data
  54. settings.i_numSamplesPerClass = settings.i_numSamplesPerClassUnlabeled;
  55. unlabeledData = sampleData( settings );
  56. % sample test data
  57. settings.i_numSamplesPerClass = settings.i_numSamplesPerClassTest;
  58. testData = sampleData( settings );
  59. % in which order to we want to add the 'unlabeled' sampels?
  60. idxToAdd = randperm( length(unlabeledData.y) );
  61. %% (2) run both techniques : efficient updates versus batch training
  62. b_efficientUpdates = false;
  63. resultsBatch = simulateIncrementalLearning ( labeledData, unlabeledData, testData, b_efficientUpdates, idxToAdd );
  64. timesBatch(i_idxRep, :) = resultsBatch.times;
  65. arrBatch(i_idxRep, :) = resultsBatch.ARR;
  66. clear ( 'resultsBatch' );
  67. b_efficientUpdates = true;
  68. resultsEfficient = simulateIncrementalLearning ( labeledData, unlabeledData, testData, b_efficientUpdates, idxToAdd );
  69. timesEfficient(i_idxRep, :) = resultsEfficient.times;
  70. arrEfficient(i_idxRep, :) = resultsEfficient.ARR;
  71. clear ( 'resultsEfficient' );
  72. end
  73. %% (3) evaluation
  74. % compute relevant values
  75. timesBatch = mean ( timesBatch ) ;
  76. timesEfficient = mean ( timesEfficient ) ;
  77. arrBatch = mean( arrBatch );
  78. arrEfficient = mean ( arrEfficient );
  79. % setup variables
  80. if ( ( ~isfield(settings,'s_legendLocation')) || isempty(settings.s_legendLocation) )
  81. s_legendLocation = 'NorthEast';
  82. else
  83. s_legendLocation = settings.s_legendLocation;
  84. end
  85. if ( ( ~isfield(settings,'i_fontSize')) || isempty(settings.i_fontSize) )
  86. i_fontSize = 12;
  87. else
  88. i_fontSize = settings.i_fontSize;
  89. end
  90. if ( ( ~isfield(settings,'i_fontSizeAxis')) || isempty(settings.i_fontSizeAxis) )
  91. i_fontSizeAxis = 16;
  92. else
  93. i_fontSizeAxis = settings.i_fontSizeAxis;
  94. end
  95. if ( ( ~isfield(settings,'c')) || isempty(settings.c) )
  96. c={[0,0,1], [1,0,0]};
  97. else
  98. c = settings.c;
  99. end
  100. if ( ( ~isfield(settings,'lineStyle')) || isempty(settings.lineStyle) )
  101. lineStyle={'-', '--' };
  102. else
  103. lineStyle = settings.lineStyle;
  104. end
  105. if ( ( ~isfield(settings,'marker')) || isempty(settings.marker) )
  106. marker={'none', 'none'};
  107. else
  108. marker = settings.marker;
  109. end
  110. if ( ( ~isfield(settings,'linewidth')) || isempty(settings.linewidth) )
  111. linewidth=3;
  112. else
  113. linewidth = settings.linewidth;
  114. end
  115. %% plot computation times (blue should be lower than red)
  116. fig_timeComp = figure;
  117. set ( fig_timeComp, 'name', 'Computation Times for Model Updates');
  118. hold on;
  119. plot ( timesEfficient, 'Color', c{ 1 }, 'LineStyle', lineStyle{ 1 }, 'Marker', marker{ 1 }, ...
  120. 'LineWidth', linewidth, 'MarkerSize',8 );
  121. plot ( timesBatch, 'Color', c{ 2 }, 'LineStyle', lineStyle{ 2 }, 'Marker', marker{ 2 }, ...
  122. 'LineWidth', linewidth, 'MarkerSize',8);
  123. leg=legend( {'Efficient', 'Batch'}, 'Location', s_legendLocation,'fontSize', i_fontSize,'LineWidth', 3);
  124. xlabel('Number of samples added');
  125. ylabel({'Time spent for model update [s]'});
  126. text_h=findobj(gca,'type','text');
  127. set(text_h,'FontSize',i_fontSize);
  128. set(gca, 'FontSize', i_fontSize);
  129. set(get(gca,'YLabel'), 'FontSize', i_fontSizeAxis);
  130. set(get(gca,'XLabel'), 'FontSize', i_fontSizeAxis);
  131. hold off;
  132. %% plot accuracies (blue and red should be identical)
  133. fig_accuracyComp = figure;
  134. set ( fig_accuracyComp, 'name', 'Accuracy over time');
  135. hold on;
  136. plot ( arrEfficient, 'Color', c{ 1 }, 'LineStyle', lineStyle{ 1 }, 'Marker', marker{ 1 }, ...
  137. 'LineWidth', linewidth, 'MarkerSize',8 );
  138. plot ( arrBatch, 'Color', c{ 2 }, 'LineStyle', lineStyle{ 2 }, 'Marker', marker{ 2 }, ...
  139. 'LineWidth', linewidth, 'MarkerSize',8);
  140. leg=legend( {'Efficient', 'Batch'}, 'Location', s_legendLocation,'fontSize', i_fontSize,'LineWidth', 3);
  141. xlabel('Number of samples added');
  142. ylabel({'Time spent for model update [s]'});
  143. text_h=findobj(gca,'type','text');
  144. set(text_h,'FontSize',i_fontSize);
  145. set(gca, 'FontSize', i_fontSize);
  146. set(get(gca,'YLabel'), 'FontSize', i_fontSizeAxis);
  147. set(get(gca,'XLabel'), 'FontSize', i_fontSizeAxis);
  148. hold off;
  149. pause;
  150. close ( fig_timeComp );
  151. close ( fig_accuracyComp );
  152. end
  153. function data = sampleData( settings )
  154. if ( isfield( settings, 'i_numDim') && ~isempty(settings.i_numDim) )
  155. i_numDim = settings.i_numDim;
  156. else
  157. i_numDim = 5;
  158. end
  159. if ( isfield( settings, 'd_mean') && ~isempty(settings.d_mean) )
  160. d_mean = settings.d_mean;
  161. else
  162. d_mean = 3.0;
  163. end
  164. if ( isfield( settings, 'd_var') && ~isempty(settings.d_var) )
  165. d_var = settings.d_var;
  166. else
  167. d_var = 1.0;
  168. end
  169. if ( isfield( settings, 'i_numClasses') && ~isempty(settings.i_numClasses) )
  170. i_numClasses = settings.i_numClasses;
  171. else
  172. i_numClasses = 3;
  173. end
  174. if ( isfield( settings, 'offsetSpecific') && ~isempty(settings.offsetSpecific) )
  175. offsetSpecific = settings.offsetSpecific;
  176. else
  177. offsetSpecific = [ 0.0, 2.0, 4.0 ];
  178. end
  179. if ( isfield( settings, 'varSpecific') && ~isempty(settings.varSpecific) )
  180. varSpecific = settings.varSpecific;
  181. else
  182. varSpecific = [2.0, 0.5, 1.0];
  183. end
  184. if ( isfield( settings, 'dimToChange') && ~isempty(settings.dimToChange) )
  185. dimToChange = settings.dimToChange;
  186. else
  187. dimToChange = [ [1,3] , [2,4] , [5] ];
  188. end
  189. if ( isfield( settings, 'i_numSamplesPerClass') && ~isempty(settings.i_numSamplesPerClass) )
  190. i_numSamplesPerClass = settings.i_numSamplesPerClass;
  191. else
  192. i_numSamplesPerClass = 2;
  193. end
  194. % randomly compute some features
  195. data.X = abs ( d_mean + d_var.*randn(i_numClasses*i_numSamplesPerClass,i_numDim) );
  196. % disturbe features for specific classes
  197. for i_clIdx = 1:i_numClasses
  198. i_idxStart = (i_clIdx-1)*i_numSamplesPerClass+1;
  199. i_idxEnd = i_clIdx*i_numSamplesPerClass;
  200. data.X( i_idxStart:i_idxEnd, dimToChange(i_clIdx) ) = ...
  201. abs ( data.X(i_idxStart:i_idxEnd,1) + ...
  202. offsetSpecific( i_clIdx) + ... %add class offset
  203. varSpecific ( i_clIdx) .*randn(i_numSamplesPerClass,1) ... % add class variance
  204. );
  205. end
  206. % normalize features, thereby simulate histograms, which commonly occure in
  207. % computer vision applications
  208. data.X = bsxfun(@times, data.X, 1./(sum(data.X, 2)));
  209. % adapt class labels
  210. data.y = bsxfun(@times, ones(i_numSamplesPerClass,1), 1:i_numClasses );
  211. data.y = data.y(:);
  212. end