% test_multiclass_gp.m (1.6 KB)
%
  function [pred_mean, pred_labels, pred_var] = test_multiclass_gp(model, Ks, Kss)
  % Prediction with multi-class GP models, outputs predictive mean and variance
  %
  % function [pred_mean, pred_labels, pred_var] = test_multiclass_gp(model, Ks, Kss)
  %
  % INPUT: model -- model structure obtained from function "learn_multiclass_gp";
  %                 fields read here: unique_labels, alpha (cell array holding one
  %                 n x 1 coefficient vector per binary task), L, and noise
  %        Ks    -- n x m matrix of kernel values between n training and m test samples
  %        Kss   -- m x 1 vector of self-similarities of test samples (optional;
  %                 only needed when pred_var is requested; a 1 x m row is accepted too)
  %
  % OUTPUT: pred_mean   -- m x c matrix of predictive mean values for each of the
  %                        m test samples and each of the c binary tasks
  %         pred_labels -- predicted multi-class label for each test sample, taken
  %                        from model.unique_labels (m x 1 when unique_labels is a
  %                        column vector; a row unique_labels yields a row)
  %         pred_var    -- m x 1 vector containing predicted variance for each test sample
  %
  % Raises an error if pred_var is requested but Kss is missing or empty.
  %
  % NOTE: to get the maximum mean value for each test sample out of the c binary
  % tasks, use: max_mean = max(pred_mean,[],2);
  % Copyright (c) by Alexander Freytag, 2013-11-13.

    % Kss is the third argument, so it is missing whenever fewer than 3 inputs are given
    if nargin < 3
        Kss = [];
    end

    nClasses = length(model.unique_labels);
    nTest    = size(Ks, 2);

    % loop over classes and thus over binary one-vs-all tasks;
    % column k holds the predictive mean of task k for every test sample
    pred_mean = zeros(nTest, nClasses);
    for k = 1:nClasses
        pred_mean(:, k) = Ks' * model.alpha{k};
    end

    % obtain predicted labels using max-pooling over the binary tasks
    [~, maxID] = max(pred_mean, [], 2);
    pred_labels = model.unique_labels(maxID);

    % compute predictive variances only if the caller asked for them
    if nargout > 2
        % check before doing the triangular solve, and fail explicitly instead of
        % returning with pred_var unassigned
        if isempty(Kss)
            error('test_multiclass_gp:noKss', 'No Kss given - cannot compute pred_var.');
        end
        % NOTE(review): model.L is presumably the lower Cholesky factor of the
        % (noise-regularized) training kernel matrix -- confirm in learn_multiclass_gp
        v = model.L \ Ks;
        % Kss(:) forces a column: a row Kss minus the column of reductions would
        % broadcast to an m x m matrix. Summing explicitly over dim 1 keeps one
        % value per test sample even when there is a single training sample (n == 1).
        pred_var = Kss(:) - sum(v .^ 2, 1)' + model.noise;
    end
  end