sq_dist.m 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  % sq_dist - a function to compute a matrix of all pairwise squared distances
  % between two sets of vectors, stored in the columns of the two matrices, a
  % (of size D by n) and b (of size D by m). If only a single argument is given
  % or the second matrix is the empty matrix [], the missing matrix is taken to
  % be identical to the first.
  %
  % Usage: C = sq_dist(a, b)
  %    or: C = sq_dist(a)  or equiv.: C = sq_dist(a, [])
  %
  % Where a is of size Dxn, b is of size Dxm (or empty), and C is of size nxm
  % (nxn when b is missing), with C(i,j) = sum((a(:,i) - b(:,j)).^2).
  %
  % An empty b whose row count differs from D (such as the literal []) counts
  % as missing. An empty b with exactly D rows (such as zeros(D,0)) is a genuine
  % set of zero vectors and yields an n-by-0 result.
  % Raises 'Error: column lengths must agree.' when a non-missing b does not
  % have D rows.
  %
  % Copyright (c) by Carl Edward Rasmussen and Hannes Nickisch, 2010-12-13.
  function C = sq_dist(a, b)
  % borrowed from gpml-toolbox of Rasmussen and Nickisch
  % see http://www.gaussianprocess.org/gpml/code/matlab/doc/
  if nargin<1 || nargin>2 || nargout>1, error('Wrong number of arguments.'); end
  bsx = exist('bsxfun','builtin');      % since Matlab R2007a 7.4.0 and Octave 3.0
  if ~bsx, bsx = exist('bsxfun'); end   % bsxfun is not yet "builtin" in Octave
  [D, n] = size(a);
  % Decide whether b is missing. The || short-circuits, so b is never touched
  % when it was not passed. An empty b that cannot have D rows (e.g. []) is the
  % documented "same as a" shorthand; an empty Dx0 b is kept as a real input.
  selfdist = nargin < 2 || (isempty(b) && size(b,1) ~= D);
  % Computation of a^2 - 2*a*b + b^2 is less stable than (a-b)^2 because numerical
  % precision can be lost when both a and b have very large absolute value and the
  % same sign. For that reason, we subtract the mean from the data beforehand to
  % stabilise the computations. This is OK because the squared distance is
  % invariant to a common shift of both point sets.
  if selfdist                                     % subtract mean of a
    mu = mean(a,2);
    if bsx
      a = bsxfun(@minus,a,mu);
    else
      a = a - repmat(mu,1,n);
    end
    b = a; m = n;
  else
    [d, m] = size(b);
    if d ~= D, error('Error: column lengths must agree.'); end
    % mean of the pooled columns [a b], weighted by how many columns each has
    mu = (m/(n+m))*mean(b,2) + (n/(n+m))*mean(a,2);
    if bsx
      a = bsxfun(@minus,a,mu); b = bsxfun(@minus,b,mu);
    else
      a = a - repmat(mu,1,n); b = b - repmat(mu,1,m);
    end
  end
  if bsx                                          % compute squared distances
    C = bsxfun(@plus,sum(a.*a,1)',bsxfun(@minus,sum(b.*b,1),2*a'*b));
  else
    C = repmat(sum(a.*a,1)',1,m) + repmat(sum(b.*b,1),n,1) - 2*a'*b;
  end
  C = max(C,0);   % numerical noise can cause C to be negative, i.e. C > -1e-14