|
@@ -5,9 +5,9 @@ function multiclass_gp_1D_example()
|
|
|
% Copyright (c) by Alexander Freytag, 2013-11-13.
|
|
|
|
|
|
% get some artificial 1D data
|
|
|
- train_data = [1;2.5;4.5];
|
|
|
+ train_data = [1;2.5;4.5]';
|
|
|
train_labels = [1; 2; 4];
|
|
|
- test_data = (0.1:0.01:5)';
|
|
|
+ test_data = (0.1:0.01:5);
|
|
|
|
|
|
% some default settings of hyperparameters
|
|
|
loghyper = [-1 0];
|
|
@@ -15,16 +15,16 @@ function multiclass_gp_1D_example()
|
|
|
|
|
|
|
|
|
% learn multi-class GP model
|
|
|
- K = feval('covSEiso', loghyper, train_data);
|
|
|
+ K = feval('covSEiso', loghyper, train_data');
|
|
|
model = learn_multiclass_gp(K, train_labels, gpnoise);
|
|
|
|
|
|
% evaluate model on test data
|
|
|
- Ks = feval('covSEiso', loghyper, train_data, test_data);
|
|
|
- Kss = feval('covSEiso', loghyper, test_data, 'diag');
|
|
|
- [mu, pred_labels, variance] = test_multiclass_gp(model, Ks, Kss');
|
|
|
+ Ks = feval('covSEiso', loghyper, train_data', test_data');
|
|
|
+ Kss = feval('covSEiso', loghyper, test_data', 'diag');
|
|
|
+ [mu, pred_labels, variance] = test_multiclass_gp(model, Ks, Kss);
|
|
|
|
|
|
% visualize everything nicely
|
|
|
- f1 = figure(1);
|
|
|
+ hfigMean = figure(1);
|
|
|
plot(train_data(1), 1, 'bx');
|
|
|
title('Mean curve of each binary task');
|
|
|
|
|
@@ -37,25 +37,106 @@ function multiclass_gp_1D_example()
|
|
|
plot(test_data, mu(:,3), 'r');
|
|
|
hold off
|
|
|
|
|
|
- f2 = figure(2);
|
|
|
+ hfigPred = figure(2);
|
|
|
plot(test_data, pred_labels, 'kx');
|
|
|
title('Predicted multi-class labels');
|
|
|
|
|
|
- f3 = figure(3);
|
|
|
+ hfigPosterior = figure(3);
|
|
|
title('Multi-class posterior mean and variance');
|
|
|
colors = {'b', 'g', 'r'};
|
|
|
hold on
|
|
|
max_mean = max(mu,[],2);
|
|
|
lower = max_mean-sqrt(variance);
|
|
|
upper = max_mean+sqrt(variance);
|
|
|
- p = [test_data, lower; flipdim(test_data,1),flipdim(upper,1)];
|
|
|
+ p = [test_data', lower; flipdim(test_data',1),flipdim(upper,1)];
|
|
|
fill(p(:,1), p(:,2), 'y');
|
|
|
+
|
|
|
for k=1:length(model.unique_labels)
|
|
|
|
|
|
tmp_ID = pred_labels == model.unique_labels(k);
|
|
|
- plot(test_data(tmp_ID),mu(tmp_ID,k),colors{k}, 'LineWidth', 2);
|
|
|
+ plot( test_data(tmp_ID)', ...
|
|
|
+ mu(tmp_ID,k),...
|
|
|
+ colors{k}, ...
|
|
|
+ 'LineWidth', 2 ...
|
|
|
+ );
|
|
|
|
|
|
end
|
|
|
hold off
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ %% update
|
|
|
+
|
|
|
+ disp( 'Now add one example...' )
|
|
|
+
|
|
|
+ pause
|
|
|
+
|
|
|
+
|
|
|
+ train_data_new = [2.0];
|
|
|
+ train_labels_new = [1];
|
|
|
+
|
|
|
+ Ks = feval('covSEiso', loghyper, train_data', train_data_new');
|
|
|
+ Kss = feval('covSEiso', loghyper, train_data_new', 'diag');
|
|
|
+ model = update_multiclass_gp(model, Kss, Ks, train_labels_new);
|
|
|
+
|
|
|
+
|
|
|
+ train_data = [train_data, train_data_new];
|
|
|
+
|
|
|
+ % evaluate model on test data
|
|
|
+ Ks = feval('covSEiso', loghyper, train_data', test_data');
|
|
|
+ Kss = feval('covSEiso', loghyper, test_data', 'diag');
|
|
|
+ [mu, pred_labels, variance] = test_multiclass_gp(model, Ks, Kss);
|
|
|
+
|
|
|
+ % visualize everything nicely
|
|
|
+ hfigMeanUpd = figure(4);
|
|
|
+ hold on
|
|
|
+ plot(train_data(1), 1, 'bx');
|
|
|
+ plot(train_data(4), 1, 'bx');
|
|
|
+ title('Mean curve of each binary task');
|
|
|
+
|
|
|
+
|
|
|
+ plot(train_data(2), 1, 'gx');
|
|
|
+ plot(train_data(3), 1, 'rx');
|
|
|
+
|
|
|
+ plot(test_data, mu(:,1), 'b');
|
|
|
+ plot(test_data, mu(:,2), 'g');
|
|
|
+ plot(test_data, mu(:,3), 'r');
|
|
|
+ hold off
|
|
|
+
|
|
|
+ hfigPredUpd = figure(5);
|
|
|
+ plot(test_data', pred_labels, 'kx');
|
|
|
+ title('Predicted multi-class labels');
|
|
|
+
|
|
|
+ hfigPosteriorUpd = figure(6);
|
|
|
+ title('Multi-class posterior mean and variance');
|
|
|
+ colors = {'b', 'g', 'r'};
|
|
|
+ hold on
|
|
|
+ max_mean = max(mu,[],2);
|
|
|
+ lower = max_mean-sqrt(variance);
|
|
|
+ upper = max_mean+sqrt(variance);
|
|
|
+ p = [test_data', lower; flipdim(test_data',1),flipdim(upper,1)];
|
|
|
+ fill(p(:,1), p(:,2), 'y');
|
|
|
+ for k=1:length(model.unique_labels)
|
|
|
+
|
|
|
+ tmp_ID = pred_labels == model.unique_labels(k);
|
|
|
+ plot( test_data(tmp_ID)', ...
|
|
|
+ mu(tmp_ID,k), ...
|
|
|
+ colors{k}, ...
|
|
|
+ 'LineWidth', 2 ...
|
|
|
+ );
|
|
|
+
|
|
|
+ end
|
|
|
+ hold off
|
|
|
+
|
|
|
+
|
|
|
+ pause
|
|
|
+
|
|
|
+ close ( hfigMean );
|
|
|
+ close ( hfigPred );
|
|
|
+ close ( hfigPosterior );
|
|
|
+ %
|
|
|
+ close ( hfigMeanUpd );
|
|
|
+ close ( hfigPredUpd );
|
|
|
+ close ( hfigPosteriorUpd );
|
|
|
|
|
|
end
|