simulateIncrementalLearning.m 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  function results = simulateIncrementalLearning ( labeledData, unlabeledData, testData, b_efficientUpdates, idxToAdd )
  % SIMULATEINCREMENTALLEARNING Train a multi-class GP classifier on labeled
  % data, then add unlabeled samples one at a time (with their known labels,
  % simulating a user answering queries) and evaluate on a test set after
  % every addition.
  %
  %   results = simulateIncrementalLearning ( labeledData, unlabeledData, testData )
  %   results = simulateIncrementalLearning ( ..., b_efficientUpdates, idxToAdd )
  %
  % INPUT
  %   labeledData        - struct with fields X (one sample per column) and
  %                        y (column vector of labels) for the initial model
  %   unlabeledData      - struct with fields X (one sample per column) and y;
  %                        y(i) is used as the "user-provided" label of X(:,i)
  %   testData           - struct with fields X (one sample per column) and y
  %                        used for evaluation
  %   b_efficientUpdates - (optional, default: true) if true, the model is
  %                        updated with update_multiclass_gp; if false, it is
  %                        re-trained from scratch on the enlarged kernel matrix
  %   idxToAdd           - (optional, default: 1:length(unlabeledData.y))
  %                        indices into unlabeledData giving which samples are
  %                        added, and in which order
  %
  % OUTPUT
  %   results.ARR   - (1+numel(idxToAdd)) x 1 vector. Entry 1 is the test
  %                   accuracy of the initial model; entry k+1 is the accuracy
  %                   after k additions. Accuracy is the overall fraction of
  %                   correctly classified test samples (NOT averaged per class).
  %   results.times - (1+numel(idxToAdd)) x 1 vector of seconds. Entry 1 is the
  %                   initial training; entry k+1 is the k-th model update.
  %                   Kernel evaluations are not included in these timings.
  %
  % Copyright (c) by Alexander Freytag, 2013-11-13.

    if ( (nargin < 4) || isempty (b_efficientUpdates) )
        b_efficientUpdates = true;
    end
    if ( (nargin < 5) || isempty (idxToAdd) )
        idxToAdd = 1:length(unlabeledData.y);
    end

    labeled_X   = labeledData.X;
    labeled_y   = labeledData.y;
    unlabeled_X = unlabeledData.X;
    unlabeled_y = unlabeledData.y;
    test_X      = testData.X;
    test_y      = testData.y;

    % output: one entry for the initial model plus one per added sample.
    % Sized by idxToAdd, since that is what the update loop iterates over.
    numIncrements = numel(idxToAdd);
    ARRscore      = zeros ( 1+numIncrements, 1 );
    times         = zeros ( 1+numIncrements, 1 );

    %% (2.1) initially train the model
    % settings
    gpnoise  = 0.1;          % noise parameter passed to learn_multiclass_gp
    covFunc  = {'covmin'};   % kernel function name (named covFunc to avoid shadowing the built-in cov)
    loghyper = [];           % no kernel hyperparameters are passed to covmin

    K = feval(covFunc{:}, loghyper, labeled_X');
    timeStamp_trainStart = tic;
    model = learn_multiclass_gp(K, labeled_y', gpnoise);
    times(1) = toc(timeStamp_trainStart);

    %% (2.2) test initial model
    % compute similarities between training and test samples
    % (one row per training sample; rows are appended below as samples are added)
    Ks = feval(covFunc{:}, loghyper, labeled_X', test_X');
    [~, pred_labels] = test_multiclass_gp(model, Ks);
    ARRscore(1) = computeAccuracy ( pred_labels, test_y );

    %% (3) perform incremental updates
    for i_idxAdd = 1:numIncrements
        %%%%%%%%%%%%%%
        %% update
        %%%%%%%%%%%%%%
        % get new sample to add
        xNew = unlabeled_X( :, idxToAdd(i_idxAdd) );
        % simulate asking of user
        yNew = unlabeled_y ( idxToAdd(i_idxAdd) );

        % kernel values of the new sample with itself and with all current
        % training samples (computed before labeled_X grows, outside the timed part)
        kss         = feval(covFunc{:}, loghyper, xNew');
        ksNewSample = feval(covFunc{:}, loghyper, labeled_X', xNew');

        if ( b_efficientUpdates )
            timeStamp_updateStart = tic;
            model = update_multiclass_gp(model, kss, ksNewSample, yNew);
            times(i_idxAdd+1) = toc(timeStamp_updateStart);
        else
            % enlarge kernel matrix and label vector, then re-train from scratch
            K         = [K, ksNewSample; ksNewSample', kss];
            labeled_y = [labeled_y; yNew];
            timeStamp_updateStart = tic;
            model = learn_multiclass_gp(K, labeled_y', gpnoise);
            times(i_idxAdd+1) = toc(timeStamp_updateStart);
        end
        labeled_X = [labeled_X, xNew];

        %%%%%%%%%%%%%
        %% test
        %%%%%%%%%%%%%
        % append the new training sample's similarities to all test samples
        Ks = [Ks; feval(covFunc{:}, loghyper, xNew', test_X')];
        [~, pred_labels] = test_multiclass_gp(model, Ks);
        ARRscore(i_idxAdd+1) = computeAccuracy ( pred_labels, test_y );
    end

    results.ARR   = ARRscore;
    results.times = times;
  end

  function acc = computeAccuracy ( pred_labels, true_labels )
  % Fraction of samples whose predicted label equals the true label.
  % Both inputs are flattened with (:) so that a row/column orientation
  % mismatch cannot trigger implicit expansion into a matrix comparison.
    numCorrect = sum ( pred_labels(:) == true_labels(:) );
    acc        = double(numCorrect) / double(numel(true_labels));
  end