% multiclass_gp_1D_example.m

  1. function multiclass_gp_1D_example()
  2. % plot GP model outputs for 1D artificial data
  3. % play around with param settings to become familiar with GP models
  4. %
  5. % Copyright (c) by Alexander Freytag, 2013-11-13.
  6. % get some artificial 1D data
  7. train_data = [1;2.5;4.5];
  8. train_labels = [1; 2; 4];
  9. test_data = (0.1:0.01:5)';
  10. % some default settings of hyperparameters
  11. loghyper = [-1 0];
  12. gpnoise = 0.01;
  13. % learn multi-class GP model
  14. K = feval('covSEiso', loghyper, train_data);
  15. model = learn_multiclass_gp(K, train_labels, gpnoise);
  16. % evaluate model on test data
  17. Ks = feval('covSEiso', loghyper, train_data, test_data);
  18. Kss = feval('covSEiso', loghyper, test_data, 'diag');
  19. [mu, pred_labels, variance] = test_multiclass_gp(model, Ks, Kss');
  20. % visualize everything nicely
  21. f1 = figure(1);
  22. plot(train_data(1), 1, 'bx');
  23. title('Mean curve of each binary task');
  24. hold on
  25. plot(train_data(2), 1, 'gx');
  26. plot(train_data(3), 1, 'rx');
  27. plot(test_data, mu(:,1), 'b');
  28. plot(test_data, mu(:,2), 'g');
  29. plot(test_data, mu(:,3), 'r');
  30. hold off
  31. f2 = figure(2);
  32. plot(test_data, pred_labels, 'kx');
  33. title('Predicted multi-class labels');
  34. f3 = figure(3);
  35. title('Multi-class posterior mean and variance');
  36. colors = {'b', 'g', 'r'};
  37. hold on
  38. max_mean = max(mu,[],2);
  39. lower = max_mean-sqrt(variance);
  40. upper = max_mean+sqrt(variance);
  41. p = [test_data, lower; flipdim(test_data,1),flipdim(upper,1)];
  42. fill(p(:,1), p(:,2), 'y');
  43. for k=1:length(model.unique_labels)
  44. tmp_ID = pred_labels == model.unique_labels(k);
  45. plot(test_data(tmp_ID),mu(tmp_ID,k),colors{k}, 'LineWidth', 2);
  46. end
  47. hold off
  48. end