Mercurial > forge
changeset 4767:8e47c13c48bc octave-forge
selecting RBF centers via k-means clustering
author | highegg |
---|---|
date | Thu, 17 Apr 2008 13:53:44 +0000 |
parents | f8e8ae7a7575 |
children | ceaeef5e7db9 |
files | main/octgpr/ChangeLog main/octgpr/DESCRIPTION main/octgpr/inst/demo_octgpr.m main/octgpr/inst/rbf_centers.m main/octgpr/src/ChangeLog main/octgpr/src/Makefile.in main/octgpr/src/configure.in main/octgpr/src/pdist2_mw.cc |
diffstat | 8 files changed, 601 insertions(+), 131 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/main/octgpr/ChangeLog Thu Apr 17 13:53:44 2008 +0000 @@ -0,0 +1,81 @@ +2008-04-17 Jaroslav Hajek <highegg@gmail.com> + + * inst/demo_octgpr.m: refactor demo + * inst/rbf_centers.m: use kmeans++ initialization due to Arthur & + Vassilvitskii + * src/Changelog -> ChangeLog: transform to package root dir + +2008-04-15 Jaroslav Hajek <highegg@gmail.com> + + * src/pdist2_mw.cc: new source. + * inst/rbf_centers.m: new function. + +2008-04-04 Jaroslav Hajek <highegg@gmail.com> + + * src/nllgpr.f: use QR decomposition instead of normal equations + for mean fitting. + * src/train.c, src/setup.c: change workspace size for nllgpr + * src/configure.in: update the LAPACK routines checked + * src/configure: regenerate + +2008-03-26 Jaroslav Hajek <highegg@gmail.com> + + * src/config.log, src/config.status, src/config.h, src/Makefile: remove. + * src/configure.in: add support for -fPIC flags. + * src/configure, src/config.h.in: regenerate. + * src/Makefile.in: add distclean target + +2008-02-28 Jaroslav Hajek <highegg@gmail.com> + + * src/corrf.f: added Matern-3/2 and Matern-5/2 correlation funcs + * src/get_corrf.c, src/forsubs.h: included the new correlation funcs + * src/gpr_train.cc: documentation for the new correlation funcs + +2008-02-24 Jaroslav Hajek <highegg@gmail.com> + + * src/optdrv.f: implement objective reduction stopping criterion + * src/gprmod.h, src/train.c, src/gpr_train.cc: add support for the new + feature + * src/forsubs.h src/nl0gpr.f: correct subroutine name + * src/train.c: correct CP array size + +2008-02-19 Jaroslav Hajek <highegg@gmail.com> + + * src/*.{h,c,cc}: adjusted C/C++ sources to meet GNU coding standards + +2008-02-18 Jaroslav Hajek <highegg@gmail.com> + + * src/forsubs.h src/gprmod.h src/predict.c src/train.c src/setup.c: + added const modifiers where appropriate. + * src/grp_train, src/gpr_predict: replaced Array::fortran_vec with + Array::data where appropriate, to prevent unnecessary copying. + +2008-02-15 Jaroslav Hajek <highegg@gmail.com> + + * inst/demo_octgpr.m: improved demo. + +2008-02-13 Jaroslav Hajek <highegg@gmail.com> + + * src/train.c: correct allocation and copying result values. + * src/nldgpr.f: compute 2nd derivative by a better formula. + * src/gpr_train.cc: improved documentation + * src/gpr_predict.cc: improved documentation + +2008-02-12 Jaroslav Hajek <highegg@gmail.com> + + * src/dscrot.f: correct the W*D*x and W'*D*x sequence. + * src/optdrv.f: correct calls to dscrot.f + +2008-02-08 Jaroslav Hajek <highegg@gmail.com> + + * src/train.c: corrected C->Fortran calls in to use proper #defines from forsubs.h + * src/vmfac.f vmcmp.f trstep.f: deleted (no use) + * src/optdrv.f: cosmetic changes + * inst/demo_octgpr.m: added copyright + * TODO: modified + +2008-02-07 Jaroslav Hajek <highegg@gmail.com> + + * src/gpr_train.cc: corrected inline Texinfo documentation + * src/ChangeLog: added this file +
--- a/main/octgpr/DESCRIPTION Thu Apr 17 08:09:27 2008 +0000 +++ b/main/octgpr/DESCRIPTION Thu Apr 17 13:53:44 2008 +0000 @@ -1,5 +1,5 @@ Name: OctGPR -Version: 1.1.2 +Version: 1.1.3 Date: 2008-24-02 Author: Jaroslav Hajek (highegg@gmail.com) Title: Package for full dense Gaussian Process Regression
--- a/main/octgpr/inst/demo_octgpr.m Thu Apr 17 08:09:27 2008 +0000 +++ b/main/octgpr/inst/demo_octgpr.m Thu Apr 17 13:53:44 2008 +0000 @@ -18,71 +18,140 @@ % along with this software; see the file COPYING. If not, see % <http://www.gnu.org/licenses/>. % -% demo of Gaussian Process Regression package -disp("2-dimensional GPR demo"); -disp("(set global variable nsamp to the number of random samples)"); -% define the test function (the well-known matlab "peaks") -function z = testfun(x,y) - z = 4 + 3 * (1-x).^2 .* exp(-(x.^2) - (y+1).^2) - ... - 10 * (x/5 - x.^3 - y.^5) .* exp(-x.^2 - y.^2)- ... - 1/3 * exp(-(x+1).^2 - y.^2); - -end +% -*- texinfo -*- +% @deftypefn {Function File} demo_octgpr (1, nsamp = 150) +% @deftypefnx {Function File} demo_octgpr (2, ncnt = 20, npt = 500) +% OctGPR package demo function. +% First argument selects available demos: +% +% @itemize +% @item 1. GPR regression demo @* +% A function is sampled (with small noise), then reconstructed using GPR +% regression. @var{nsamp} specifies the number of samples. +% @seealso{gpr_train, gpr_predict} +% @item 2. RBF centers selection demo @* +% Radial basis centers are selected amongst random points. +% @var{ncnt} specifies number of centers, @var{npt} number of points. +% @seealso{rbf_centers} +% @end itemize +% @end deftypefn +function demo_octgpr (number, varargin) + switch (number) + case 1 + demo_octgpr1 (varargin{:}) + case 2 + demo_octgpr2 (varargin{:}) + otherwise + error ("demo_octgpr: invalid demo number") + endswitch +endfunction -tit = "matlab ""peaks"" surface"; -disp(tit); -% create the mesh onto which to interpolate -t = linspace(-3,3,50); -[xi,yi] = meshgrid(t,t); +% define the test function (the well-known matlab "peaks" plus some sines) +function z = testfun1 (x, y) + z = 4 + 3 * (1-x).^2 .* exp(-(x.^2) - (y+1).^2) ... + + 10 * (x/5 - x.^3 - y.^5) .* exp(-x.^2 - y.^2) ... + - 1/3 * exp(-(x+1).^2 - y.^2) ... + + 2*sin (x + y + 1e-1*x.*y); +endfunction + +function demo_octgpr1 (nsamp = 150) + tit = "a peaked surface"; + disp (tit); + + % create the mesh onto which to interpolate + t = linspace (-3, 3, 50); + [xi,yi] = meshgrid (t, t); -% evaluate -zi = testfun(xi,yi); -zimax = max(vec(zi)); zimin = min(vec(zi)); -subplot(1,2,1); -mesh(xi,yi,zi); -title(tit); -pause; + % evaluate + zi = testfun1 (xi, yi); + zimax = max (vec (zi)); zimin = min (vec (zi)); + subplot (2, 2, 1); + mesh (xi, yi, zi); + title (tit); + subplot (2, 2, 3); + contourf (xi, yi, zi, 20); + pause; + + if (!exist ("nsamp", "var") || !isnumeric (nsamp)) + nsamp = 150; + endif + + tit = sprintf ("sampled at %d random points", nsamp); + disp (tit); + % create random samples + xs = rand (nsamp,1); ys = rand (nsamp,1); + xs = 6*xs-3; ys = 6*ys - 3; + % evaluate at random samples + zs = testfun1 (xs, ys); + xys = [xs ys]; -if (!exist("nsamp","var") || !isnumeric(nsamp)) - nsamp = 150 -end -tit = sprintf("sampled at %d random points",nsamp); -disp(tit); -% create random samples -xs = rand(nsamp,1); ys = rand(nsamp,1); -xs = 6*xs-3; ys = 6*ys - 3; -% evaluate at random samples -zs = testfun(xs,ys); -xys = [xs ys]; + subplot (2, 2, 2); + plot3 (xs, ys, zs, ".+"); + title (tit); + subplot (2, 2, 4); + plot (xs, ys, ".+"); + pause -subplot(1,2,2); -plot3(xs,ys,zs,".+"); -title(tit); -pause; + tit = "GPR model with heuristic hypers"; + disp (tit); + ths = 1 ./ std (xys); + GPM = gpr_train (xys, zs, ths, 1e-5); + zm = gpr_predict (GPM, [vec(xi) vec(yi)]); + zm = reshape (zm, size(zi)); + zm = min (zm, zimax); zm = max (zm, zimin); + subplot (2, 2, 2); + mesh (xi, yi, zm); + title (tit); + subplot(2, 2, 4) + hold on + contourf (xi, yi, zm, 20); + plot (xs, ys, "+6"); + hold off + pause -tit = "GPR model with heuristic hypers"; -disp(tit); -ths = 1 ./ std(xys); -GPM = gpr_train(xys,zs,ths,1e-5); -zm = gpr_predict(GPM,[vec(xi) vec(yi)]); -zm = reshape(zm,size(zi)); -zm = min(zm,zimax); zm = max(zm,zimin); -subplot(1,2,2); -mesh(xi,yi,zm); -title(tit); -pause; + tit = "GPR model with MLE training"; + disp (tit); + fflush (stdout); + GPM = gpr_train (xys, zs, ths, 1e-5, {"tol", 1e-5, "maxev", 400, "numin", 1e-8}); + zm = gpr_predict (GPM, [vec(xi) vec(yi)]); + zm = reshape (zm, size (zi)); + zm = min (zm, zimax); zm = max (zm, zimin); + subplot (2, 2, 2); + mesh (xi, yi, zm); + title (tit); + subplot(2, 2, 4) + hold on + contourf (xi, yi, zm, 20); + plot (xs, ys, "+6"); + hold off + pause + + close +endfunction + +function demo_octgpr2 (ncnt = 50, npt = 500) -tit = "GPR model with MLE training"; -disp(tit); -fflush(stdout); -GPM = gpr_train(xys,zs,ths,1e-5,{"tol",1e-5,"maxev",400,"numin",1e-8}); -zm = gpr_predict(GPM,[vec(xi) vec(yi)]); -zm = reshape(zm,size(zi)); -zm = min(zm,zimax); zm = max(zm,zimin); -subplot(1,2,2); -mesh(xi,yi,zm); -title(tit); -pause; + npt = ncnt*ceil (npt/ncnt); + U = rand (ncnt, 2); + cs = min (pdist2_mw (U, 2) + diag (Inf (ncnt, 1))); + X = repmat (U, npt/ncnt, 1) + repmat (cs', npt/ncnt, 2) .* randn (npt, 2); + disp ("slightly clustered random points") + plot (X(:,1), X(:,2), "+"); + pause + + [U, ur] = rbf_centers(X, ncnt); -close + fi = linspace (0, 2*pi, 20); + ncolors = rows (colormap); + hold on + for i = 1:rows (U) + xc = U(i,1) + ur(i) * cos (fi); + yc = U(i,2) + ur(i) * sin (fi); + line (xc, yc); + endfor + hold off + pause + close + +endfunction
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/main/octgpr/inst/rbf_centers.m Thu Apr 17 13:53:44 2008 +0000 @@ -0,0 +1,97 @@ +% Copyright (C) 2008 VZLU Prague, a.s., Czech Republic +% +% Author: Jaroslav Hajek <highegg@gmail.com> +% +% This file is part of OctGPR. +% +% OctGPR is free software; you can redistribute it and/or modify +% it under the terms of the GNU General Public License as published by +% the Free Software Foundation; either version 2 of the License, or +% (at your option) any later version. +% +% This program is distributed in the hope that it will be useful, +% but WITHOUT ANY WARRANTY; without even the implied warranty of +% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +% GNU General Public License for more details. +% +% You should have received a copy of the GNU General Public License +% along with this software; see the file COPYING. If not, see +% <http://www.gnu.org/licenses/>. +% + +% -*- texinfo -*- +% @deftypefn {Function File} {[U, ur, iu]} = rbf_centers (@var{X}, @var{nu}, @var{theta}) +% Selects a given number of RBF centers based on Lloyd's clustering algorithm. +% +% @end deftypefn +function [U, ur, iu] = rbf_centers (X, nu, theta) + + pso_old = page_screen_output (0); + + if (nargin == 3) + X = dmult (X, theta); + elseif (nargin != 2) + print_usage (); + endif + + disp ("initializing ..."); + % the D^2 weighting initialization + + D = Inf; + kk = 1:rows (X); + cp = kk; + + for i = 1:nu + jj = sum (rand() * cp(end) < cp); + k(i) = kk(jj); + kk(jj) = []; + U = X(k(i),:); + D = min (D, pdist2_mw(X, U, 'ssq')'); + cp = cumsum (D(kk)); + endfor + + + % now perform the k-means algorithm + + U = X(k,:); + D = pdist2_mw(U, X, 'ssq'); + [xx, j] = min (D); + + it = 0; + do + for i = 1:nu + ij = find(j == i); + if (!isempty (ij)) + U(i,:) = mean (X(ij,:)); + else + U(i,:) = X(ceil (rand () * rows (X)), :); + endif + endfor + j1 = j; + D = pdist2_mw (U, X, 'ssq'); + [xx, j] = min (D); + printf ("k-means iteration %d\r", ++it); + fflush (stdout); + until (all (j == j1)) + printf ("\n"); + + if (nargout > 2) + iu = j; + endif + + if (nargout > 1) + ur = zeros (nu, 1); + for i = 1:nu + ij = (j == i); + ur(i) = sqrt (max (D(i,ij))); + endfor + endif + + if (nargin == 3) + U = dmult (U, 1./theta); + if (any(theta == 0)) + U(:,theta == 0) = 0; + endif + endif + +endfunction
--- a/main/octgpr/src/ChangeLog Thu Apr 17 08:09:27 2008 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,68 +0,0 @@ -2008-04-04 Jaroslav Hajek <highegg@gmail.com> - - * nllgpr.f: use QR decomposition instead of normal equations - for mean fitting. - * train.c, setup.c: change workspace size for nllgpr - * configure.in: update the LAPACK routines checked - * configure: regenerate - -2008-03-26 Jaroslav Hajek <highegg@gmail.com> - - * config.log, config.status, config.h, Makefile: remove. - * configure.in: add support for -fPIC flags. - * configure, config.h.in: regenerate. - * Makefile.in: add distclean target - -2008-02-28 Jaroslav Hajek <highegg@gmail.com> - - * corrf.f: added Matern-3/2 and Matern-5/2 correlation funcs - * get_corrf.c, forsubs.h: included the new correlation funcs - * gpr_train.cc: documentation for the new correlation funcs - -2008-02-24 Jaroslav Hajek <highegg@gmail.com> - - * optdrv.f: implement objective reduction stopping criterion - * gprmod.h, train.c, gpr_train.cc: add support for the new feature - * forsubs.h nl0gpr.f: correct subroutine name - * train.c: correct CP array size - -2008-02-19 Jaroslav Hajek <highegg@gmail.com> - - * *.h *c *.cc: adjusted C/C++ sources to meet GNU coding standards - -2008-02-18 Jaroslav Hajek <highegg@gmail.com> - - * forsubs.h gprmod.h predict.c train.c setup.c: added const modifiers - where appropriate. - * grp_train, gpr_predict: replaced Array::fortran_vec with Array::data - where appropriate, to prevent unnecessary copying. - -2008-02-15 Jaroslav Hajek <highegg@gmail.com> - - * octgpr.m: improved demo. - -2008-02-13 Jaroslav Hajek <highegg@gmail.com> - - * train.c: correct allocation and copying result values. - * nldgpr.f: compute 2nd derivative by a better formula. - * gpr_train.cc: improved documentation - * gpr_predict.cc: improved documentation - -2008-02-12 Jaroslav Hajek <highegg@gmail.com> - - * dscrot.f: correct the W*D*x and W'*D*x sequence. - * optdrv.f: correct calls to dscrot.f - -2008-02-08 Jaroslav Hajek <highegg@gmail.com> - - * train.c: corrected C->Fortran calls in to use proper #defines from forsubs.h - * vmfac.f vmcmp.f trstep.f: deleted (no use) - * optdrv.f: cosmetic changes - * ../inst/demo_octgpr.m: added copyright - * ../TODO: modified - -2008-02-07 Jaroslav Hajek <highegg@gmail.com> - - * gpr_train.cc: corrected inline Texinfo documentation - * ChangeLog: added this file -
--- a/main/octgpr/src/Makefile.in Thu Apr 17 08:09:27 2008 +0000 +++ b/main/octgpr/src/Makefile.in Thu Apr 17 13:53:44 2008 +0000 @@ -35,7 +35,7 @@ OBJS_GPR_PRED=dsdacc.o dwdis2.o infgpr.o corrf.o \ get_corrf.o predict.o -all: gpr_train.oct gpr_predict.oct +all: gpr_train.oct gpr_predict.oct pdist2_mw.oct %.o: %.f $(F77) $(FFLAGS) -c $< @@ -51,6 +51,10 @@ $(MKOCTFILE) -o $@ gpr_train.o $(OBJS_GPR_TRAIN) $(LIBS) gpr_predict.oct: gpr_predict.o $(OBJS_GPR_PRED) $(MKOCTFILE) -o $@ gpr_predict.o $(OBJS_GPR_PRED) $(LIBS) + +pdist2_mw.oct: pdist2_mw.cc + $(MKOCTFILE) -o $@ $< + clean: rm -f *.o *.oct distclean: clean
--- a/main/octgpr/src/configure.in Thu Apr 17 08:09:27 2008 +0000 +++ b/main/octgpr/src/configure.in Thu Apr 17 13:53:44 2008 +0000 @@ -1,4 +1,4 @@ -s# Copyright (C) 2008 VZLU Prague, a.s., Czech Republic +# Copyright (C) 2008 VZLU Prague, a.s., Czech Republic # # Author: Jaroslav Hajek <highegg@gmail.com> #
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/main/octgpr/src/pdist2_mw.cc Thu Apr 17 13:53:44 2008 +0000 @@ -0,0 +1,287 @@ +/* Copyright (C) 2008 VZLU Prague, a.s., Czech Republic + * + * Author: Jaroslav Hajek <highegg@gmail.com> + * + * This file is part of OctGPR. + * + * OctGPR is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this software; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/>. */ + +#include <cmath> +#include <oct.h> +#include <oct-cmplx.h> + +inline double abs2(double x) +{ + return x*x; +} +inline double abs2(Complex x) +{ + return abs2(x. real()) + abs2(x.imag ()); +} + +/* distance functors */ +template<typename T> +struct distfun_eu +{ + double operator() (octave_idx_type dim, const T *x, const T* y) + { + double d = 1, scl = 0; + for (octave_idx_type i = 0; i < dim; i++) + { + double t = std::abs (x[i]-y[i]); + if (scl < t) + { + d *= abs2 (scl/t); + d += 1; + scl = t; + } + else if (t != 0) + d += abs2 (t/scl); + } + return scl * std::sqrt (d); + } +}; + +template<typename T> +struct distfun_sqeu +{ + double operator() (octave_idx_type dim, const T *x, const T* y) + { + double d = 0; + for (octave_idx_type i = 0; i < dim; i++) + d += abs2(x[i]-y[i]); + return d; + } +}; + +template<typename T> +struct distfun_l1 +{ + double operator() (octave_idx_type dim, const T *x, const T* y) + { + double d = 0; + for (octave_idx_type i = 0; i < dim; i++) + d += std::abs(x[i]-y[i]); + return d; + } +}; + +template<typename T> +struct distfun_max +{ + double operator() (octave_idx_type dim, const T *x, const T* y) + { + double d = 0; + for (octave_idx_type i = 0; i < dim; i++) + { + double t = std::abs(x[i]-y[i]); + if (t > d) d = t; + } + return d; + } +}; + +template<typename T> +struct distfun_mw +{ + double p; + distfun_mw (double _p) : p(_p) {} + double operator() (octave_idx_type dim, const T *x, const T* y) + { + double d = 1, scl = 0; + for (octave_idx_type i = 0; i < dim; i++) + { + double t = std::abs (x[i]-y[i]); + if (scl < t) + { + d *= std::pow (scl/t, p); + d += 1; + scl = t; + } + else if (t != 0) + d += std::pow (t/scl, p); + } + return scl * std::pow (d, 1/p); + } +}; + +template<typename T, class distfun> +void +fill_dist_matrix (octave_idx_type dim, octave_idx_type nx, octave_idx_type ny, + const T *X, const T* Y, double *D, distfun df) +{ + octave_idx_type i,j; + for (j = 0; j < ny; j++) + { + const T *PX = X; + for (i = 0; i < nx; i++) + { + *(D++) = df (dim, PX, Y); + PX += dim; + } + Y += dim; + } +} + +template<typename T, class distfun> +void +fill_dist_matrix (octave_idx_type dim, octave_idx_type nx, + const T *X, double *D, distfun df) +{ + octave_idx_type i,j; + for (j = 0; j < nx; j++) + { + D += j; + for (i = -j; i < 0; i++) + D[i] = *(D + i*nx); + const T *PX = X; + for (i = j; i < nx; i++) + { + *(D++) = df (dim, PX, X); + PX += dim; + } + X += dim; + } +} + +template<typename T> +Matrix +get_dist_matrix (const MArray2<T>& X, bool ssq, double p = 0) +{ + Matrix D(X.rows (), X.rows ()); + + MArray2<T> XT = X.transpose (); + + if (ssq) + fill_dist_matrix (XT.rows (), XT.cols (), + XT.data (), D.fortran_vec (), distfun_sqeu<T> ()); + else if (p == 2) + fill_dist_matrix (XT.rows (), XT.cols (), + XT.data (), D.fortran_vec (), distfun_eu<T> ()); + else if (p == 1) + fill_dist_matrix (XT.rows (), XT.cols (), + XT.data (), D.fortran_vec (), distfun_l1<T> ()); + else if (xisinf (p)) + fill_dist_matrix (XT.rows (), XT.cols (), + XT.data (), D.fortran_vec (), distfun_max<T> ()); + else + fill_dist_matrix (XT.rows (), XT.cols (), + XT.data (), D.fortran_vec (), distfun_mw<T> (p)); + return D; + +} + +template<typename T> +Matrix +get_dist_matrix (const MArray2<T>& X, const MArray2<T>& Y, bool ssq, double p = 0) +{ + Matrix D(X.rows (), Y.rows ()); + + MArray2<T> XT = X.transpose (), YT = Y.transpose (); + + if (ssq) + fill_dist_matrix (XT.rows (), XT.cols (), YT.cols (), + XT.data (), YT.data (), D.fortran_vec (), distfun_sqeu<T> ()); + else if (p == 2) + fill_dist_matrix (XT.rows (), XT.cols (), YT.cols (), + XT.data (), YT.data (), D.fortran_vec (), distfun_eu<T> ()); + else if (p == 1) + fill_dist_matrix (XT.rows (), XT.cols (), YT.cols (), + XT.data (), YT.data (), D.fortran_vec (), distfun_l1<T> ()); + else if (xisinf (p)) + fill_dist_matrix (XT.rows (), XT.cols (), YT.cols (), + XT.data (), YT.data (), D.fortran_vec (), distfun_max<T> ()); + else + fill_dist_matrix (XT.rows (), XT.cols (), YT.cols (), + XT.data (), YT.data (), D.fortran_vec (), distfun_mw<T> (p)); + return D; + +} + +DEFUN_DLD(pdist2_mw,args,, +"-*- texinfo -*-\n\ +@deftypefn {Loadable Function} @var{D} = pdist2_mw (@var{X}, @var{Y}, @var{p})\n\ +@deftypefnx {Loadable Function} @var{D} = pdist2_mw (@var{X}, @var{p})\n\ +Assembles a pairwise minkowski-distance matrix for two given sets of points.\n\ +@var{X} and @var{Y} should be real or complex matrices with a point per row,\n\ +so numbers of columns must match. The matrix contains the pairwise\n\ +distances @code{D(i,j) = norm(X(i,:)-Y(j,:),P)}.\n\ +@var{p} can also be the string \'ssq\' requesting squared euclidean distance.\n\ +(not a metric, but often useful and faster than @code{@var{p}=2})\n\ +If @var{Y} is not given, a symmetric distance matrix is calculated efficiently.\n\ +@seealso{norm}\n\ +@end deftypefn") +{ + int nargin = args.length(); + octave_value_list retval; + + if (nargin < 2 || nargin > 3) + { + print_usage (); + return retval; + } + + octave_value argx = args(0), argy, argp; + bool sym = false; + + if (nargin > 2) + argy = args(1); + else + { + argy = argx; + sym = true; + } + + if (nargin > 2) + argp = args(2); + else + argp = args(1); + + bool ssq = (argp.is_string () && argp.string_value () == "ssq"); + + if (argx.is_matrix_type () && argy.is_matrix_type () && (ssq || argp.is_real_scalar ())) + { + double p = ssq ? 0 : argp.scalar_value (); + + if (argx.columns () == argy.columns ()) + { + if (argx.is_real_matrix () && argy.is_real_matrix ()) + { + if (sym) + retval(0) = get_dist_matrix (argx.matrix_value (), + ssq, p); + else + retval(0) = get_dist_matrix (argx.matrix_value (), + argy.matrix_value (), + ssq, p); + } + else + { + if (sym) + retval(0) = get_dist_matrix (argx.complex_matrix_value (), + ssq, p); + else + retval(0) = get_dist_matrix (argx.complex_matrix_value (), + argy.complex_matrix_value (), + ssq, p); + } + } + else + error ("pmwdmat: dimension mismatch"); + } + else + error ("pmwdmat: X and Y should be matrices, p a real scalar"); + +}