% multiclass_gp_1D_example.m (3.5 KB)
  function multiclass_gp_1D_example()
  % MULTICLASS_GP_1D_EXAMPLE  Plot GP model outputs for 1D artificial data.
  %
  % Play around with the parameter settings to become familiar with GP models.
  %
  % The demo runs in two phases, each producing three figures:
  %   1) learn a multi-class GP model on three training points and plot
  %      the per-task mean curves, the predicted labels and the posterior
  %      mean with a +/- sqrt(variance) band (figures 1-3);
  %   2) after a key press, add one training example via
  %      update_multiclass_gp and plot the same three views (figures 4-6).
  % After a second key press all six figures are closed.
  %
  % Relies on covSEiso, learn_multiclass_gp, test_multiclass_gp and
  % update_multiclass_gp being available on the path.
  %
  % Copyright (c) by Alexander Freytag, 2013-11-13.

      % get some artificial 1D data (row vectors)
      train_data   = [1, 2.5, 4.5];
      train_labels = [1, 2, 4];
      test_data    = (0.1:0.01:5);

      % some default settings of hyperparameters
      % NOTE(review): loghyper is handed to covSEiso unchanged; presumably
      % these are log-scale kernel hyperparameters -- confirm against the
      % covSEiso implementation in use.
      covfunc  = 'covSEiso';
      loghyper = [-1 0];
      gpnoise  = 0.01;

      % one line/marker colour per binary task (row of mu)
      colors = {'b', 'g', 'r'};

      % learn multi-class GP model
      K     = feval(covfunc, loghyper, train_data');
      model = learn_multiclass_gp(K, train_labels, gpnoise);

      % evaluate model on test data
      [mu, pred_labels, variance] = ...
          evaluate_on_grid(model, covfunc, loghyper, train_data, test_data);

      % visualize everything nicely
      hfigMean      = plot_binary_means(1, train_data, colors, test_data, mu, colors);
      hfigPred      = plot_predicted_labels(2, test_data, pred_labels);
      hfigPosterior = plot_posterior(3, test_data, mu, pred_labels, variance, ...
                                     model.unique_labels, colors);

      %% update
      disp( 'Now add one example...' )
      pause

      train_data_new   = 2.0;
      train_labels_new = 1;

      % kernel values between old training data and the new example, plus
      % the new example's self-similarity
      Ks    = feval(covfunc, loghyper, train_data', train_data_new');
      Kss   = feval(covfunc, loghyper, train_data_new', 'diag');
      model = update_multiclass_gp(model, Kss, Ks, train_labels_new);

      train_data = [train_data, train_data_new];

      % evaluate updated model on test data
      [mu, pred_labels, variance] = ...
          evaluate_on_grid(model, covfunc, loghyper, train_data, test_data);

      % visualize everything nicely; the added example (4th point) is drawn
      % in blue like the first training point
      markers_upd      = [colors, colors(1)];
      hfigMeanUpd      = plot_binary_means(4, train_data, markers_upd, test_data, mu, colors);
      hfigPredUpd      = plot_predicted_labels(5, test_data, pred_labels);
      hfigPosteriorUpd = plot_posterior(6, test_data, mu, pred_labels, variance, ...
                                        model.unique_labels, colors);

      pause

      close ( hfigMean );
      close ( hfigPred );
      close ( hfigPosterior );
      %
      close ( hfigMeanUpd );
      close ( hfigPredUpd );
      close ( hfigPosteriorUpd );
  end

  function [mu, pred_labels, variance] = evaluate_on_grid(model, covfunc, loghyper, train_data, test_data)
  % Compute train/test and test-diagonal kernel values and run the model on them.
      Ks  = feval(covfunc, loghyper, train_data', test_data');
      Kss = feval(covfunc, loghyper, test_data', 'diag');
      [mu, pred_labels, variance] = test_multiclass_gp(model, Ks, Kss);
  end

  function hfig = plot_binary_means(fig_id, train_data, marker_colors, test_data, mu, curve_colors)
  % Draw one 'x' marker per training point (at height 1) and one mean curve per row of mu.
  %   marker_colors{i} is the colour of the i-th training point,
  %   curve_colors{k} the colour of the curve mu(k,:).
      hfig = figure(fig_id);

      % The first plot is issued before "hold on" so that it resets the
      % axes, as the original code did for figure 1.
      plot(train_data(1), 1, [marker_colors{1}, 'x']);
      title('Mean curve of each binary task');
      hold on
      for i = 2:numel(train_data)
          plot(train_data(i), 1, [marker_colors{i}, 'x']);
      end
      for k = 1:numel(curve_colors)
          plot(test_data, mu(k,:), curve_colors{k});
      end
      hold off
  end

  function hfig = plot_predicted_labels(fig_id, test_data, pred_labels)
  % Scatter the predicted label of every test input.
      hfig = figure(fig_id);
      plot(test_data, pred_labels, 'kx');
      title('Predicted multi-class labels');
  end

  function hfig = plot_posterior(fig_id, test_data, mu, pred_labels, variance, unique_labels, colors)
  % Draw the max-mean +/- sqrt(variance) band and, per class, the mean where that class is predicted.
      hfig = figure(fig_id);
      title('Multi-class posterior mean and variance');
      hold on

      max_mean = max(mu, [], 1);
      spread   = sqrt(variance);
      lower    = max_mean - spread;
      upper    = max_mean + spread;

      % Closed polygon: lower bound left-to-right, upper bound right-to-left.
      % fliplr replaces the deprecated flipdim(x,2); both reverse the
      % column order.
      px = [test_data, fliplr(test_data)];
      py = [lower, fliplr(upper)];
      fill(px, py, 'y');

      for k = 1:length(unique_labels)
          % test inputs whose predicted label equals the k-th class label
          is_class_k = pred_labels == unique_labels(k);
          plot( test_data(is_class_k), ...
                mu(k, is_class_k), ...
                colors{k}, ...
                'LineWidth', 2 ...
              );
      end
      hold off
  end