changeset 29293:512591ccb174

interpn.m: Overhaul function. * interpn.m: Use @code rather than @qcode for 'NA' value in documentation. Move initial input validation checking number of arguments to head of file. Add comments to help explain what each block of code is doing. Call validatestring() with third argument so that it reports originating function. Fix incorrect input validation when nargs <= 2. Replace two calls to size_equal() with one for better performance. Make error messages for input validation more specific. Recode to eliminate unnecessary temporary variable "foobar". Use false() rather than zeros() when creating logical indices to save 7/8 memory. Add input validation for unimplemented METHOD "pchip". Remove unnecessary error message for unrecognized interpolation METHOD. Add input validation BIST tests. * __splinen__.m: Remove unnecessary comments. Correct spelling in FIXME comment. Use cellfun with accelerated functions for a 2/3rds speed improvement in determining whether non-vector arrays are present. Move ndgrid() call on XI in to if block that is only executed when EXTRAPVAL is present. Use false() rather than zeros() when creating logical indices to save 7/8 memory.
author Rik <rik@octave.org>
date Tue, 12 Jan 2021 17:15:53 -0800
parents 2059b340e924
children 22215ea74b84
files scripts/general/interpn.m scripts/general/private/__splinen__.m
diffstat 2 files changed, 57 insertions(+), 42 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/general/interpn.m	Tue Jan 12 08:47:18 2021 -0800
+++ b/scripts/general/interpn.m	Tue Jan 12 17:15:53 2021 -0800
@@ -80,20 +80,21 @@
 ## must be specified as well.  If @var{extrapval} is omitted and the
 ## @var{method} is @qcode{"spline"}, then the extrapolated values of the
 ## @qcode{"spline"} are used.  Otherwise the default @var{extrapval} value for
-## any other @var{method} is @qcode{"NA"}.
+## any other @var{method} is @code{NA}.
 ## @seealso{interp1, interp2, interp3, spline, ndgrid}
 ## @end deftypefn
 
 function vi = interpn (varargin)
 
+  if (nargin < 1 || ! isnumeric (varargin{1}))
+    print_usage ();
+  endif
+
   method = "linear";
   extrapval = [];
   nargs = nargin;
 
-  if (nargin < 1 || ! isnumeric (varargin{1}))
-    print_usage ();
-  endif
-
+  ## Find and validate EXTRAPVAL and/or METHOD inputs
   if (nargs > 1 && ischar (varargin{end-1}))
     if (! isnumeric (varargin{end}) || ! isscalar (varargin{end}))
       error ("interpn: EXTRAPVAL must be a numeric scalar");
@@ -110,18 +111,18 @@
     warning ("interpn: ignoring unsupported '*' flag to METHOD");
     method(1) = [];
   endif
-  method = validatestring (method, ...
-    {"nearest", "linear", "pchip", "cubic", "spline"});
+  method = tolower (method);
+  method = validatestring (method,
+                           {"nearest", "linear", "pchip", "cubic", "spline"},
+                           "interpn");
 
-  if (nargs < 3)
+  if (nargs <= 2)
+    ## Calling form interpn (V, ...)
     v = varargin{1};
     m = 1;
     if (nargs == 2)
-      if (ischar (varargin{2}))
-        method = varargin{2};
-      elseif (isnumeric (m) && isscalar (m) && fix (m) == m)
-        m = varargin{2};
-      else
+      m = varargin{2};
+      if (! (isnumeric (m) && isscalar (m) && m == fix (m)))
         print_usage ();
       endif
     endif
@@ -136,6 +137,7 @@
     y{1} = y{1}.';
     [y{:}] = ndgrid (y{:});
   elseif (! isvector (varargin{1}) && nargs == (ndims (varargin{1}) + 1))
+    ## Calling form interpn (V, Y1, Y2, ...)
     v = varargin{1};
     sz = size (v);
     nd = ndims (v);
@@ -146,6 +148,7 @@
     endfor
   elseif (rem (nargs, 2) == 1
           && nargs == (2 * ndims (varargin{ceil (nargs / 2)})) + 1)
+    ## Calling form interpn (X1, X2, ..., V, Y1, Y2, ...)
     nv = ceil (nargs / 2);
     v = varargin{nv};
     sz = size (v);
@@ -157,26 +160,20 @@
   endif
 
   if (any (! cellfun ("isvector", x)))
-    for i = 2 : nd
-      if (! size_equal (x{1}, x{i}) || ! size_equal (x{i}, v))
-        error ("interpn: dimensional mismatch");
+    for i = 1 : nd
+      if (! size_equal (x{i}, v))
+        error ("interpn: incorrect dimensions for input X%d", i);
       endif
       idx(1 : nd) = {1};
       idx(i) = ":";
       x{i} = x{i}(idx{:})(:);
     endfor
-    idx(1 : nd) = {1};
-    idx(1) = ":";
-    x{1} = x{1}(idx{:})(:);
   endif
 
-  method = tolower (method);
-
   all_vectors = all (cellfun ("isvector", y));
   same_size = size_equal (y{:});
   if (all_vectors && ! same_size)
-    [foobar(1:numel(y)).y] = ndgrid (y{:});
-    y = {foobar.y};
+    [y{:}] = ndgrid (y{:});
   endif
 
   if (strcmp (method, "linear"))
@@ -186,19 +183,25 @@
     endif
     vi(isna (vi)) = extrapval;
   elseif (strcmp (method, "nearest"))
+    ## FIXME: This seems overly complicated.  Is there a way to simplify
+    ## all the code after the call to lookup (which should be fast)?
+    ## Could Qhull be used for quick nearest neighbor calculation?
     yshape = size (y{1});
     yidx = cell (1, nd);
+    ## Find rough nearest index using lookup function [O(log2 (N)].
     for i = 1 : nd
       y{i} = y{i}(:);
       yidx{i} = lookup (x{i}, y{i}, "lr");
     endfor
+    ## Single comparison to next largest index to see which is closer.
     idx = cell (1,nd);
     for i = 1 : nd
       idx{i} = yidx{i} ...
                + (y{i} - x{i}(yidx{i})(:) >= x{i}(yidx{i} + 1)(:) - y{i});
     endfor
     vi = v(sub2ind (sz, idx{:}));
-    idx = zeros (prod (yshape), 1);
+    ## Apply EXTRAPVAL to points outside original volume.
+    idx = false (prod (yshape), 1);
     for i = 1 : nd
       idx |= y{i} < min (x{i}(:)) | y{i} > max (x{i}(:));
     endfor
@@ -209,17 +212,15 @@
     vi = reshape (vi, yshape);
   elseif (strcmp (method, "spline"))
     if (any (! cellfun ("isvector", y)))
-      for i = 2 : nd
-        if (! size_equal (y{1}, y{i}))
-          error ("interpn: dimensional mismatch");
+      ysz = size (y{1});
+      for i = 1 : nd
+        if (any (size (y{i}) != ysz))
+          error ("interpn: incorrect dimensions for input Y%d", i);
         endif
         idx(1 : nd) = {1};
         idx(i) = ":";
         y{i} = y{i}(idx{:});
       endfor
-      idx(1 : nd) = {1};
-      idx(1) = ":";
-      y{1} = y{1}(idx{:});
     endif
 
     vi = __splinen__ (x, v, y, extrapval, "interpn");
@@ -235,10 +236,10 @@
       vi = vi(cellfun (@(x) sub2ind (size (vi), x{:}), idx));
       vi = reshape (vi, size (y{1}));
     endif
+  elseif (strcmp (method, "pchip"))
+    error ("interpn: pchip interpolation not yet implemented");
   elseif (strcmp (method, "cubic"))
     error ("interpn: cubic interpolation not yet implemented");
-  else
-    error ("interpn: unrecognized interpolation METHOD");
   endif
 
 endfunction
@@ -353,4 +354,20 @@
 %! assert (interpn (z, "spline"), zout, tol);
 
 ## Test input validation
+%!error <Invalid call> interpn ()
+%!error <Invalid call> interpn ("foobar")
+%!error <EXTRAPVAL must be a numeric scalar> interpn (1, "linear", {1})
+%!error <EXTRAPVAL must be a numeric scalar> interpn (1, "linear", [1, 2])
 %!warning <ignoring unsupported '\*' flag> interpn (rand (3,3), 1, "*linear");
+%!error <'foobar' does not match any of> interpn (1, "foobar")
+%!error <wrong number or incorrectly formatted input arguments>
+%! interpn (1, 2, 3, 4);
+%!error <incorrect dimensions for input X1>
+%! interpn ([1,2], ones (2,2), magic (3), [1,2], [1,2])
+%!error <incorrect dimensions for input X2>
+%! interpn (ones (3,3), ones (2,2), magic (3), [1,2], [1,2])
+%!error <incorrect dimensions for input Y2>
+%! interpn ([1,2], [1,2], magic (3), [1,2], ones (2,2), "spline")
+%!error <pchip interpolation not yet implemented> interpn ([1,2], "pchip")
+%!error <cubic interpolation not yet implemented> interpn ([1,2], "cubic")
+
--- a/scripts/general/private/__splinen__.m	Tue Jan 12 08:47:18 2021 -0800
+++ b/scripts/general/private/__splinen__.m	Tue Jan 12 17:15:53 2021 -0800
@@ -23,32 +23,30 @@
 ##
 ########################################################################
 
-## Undocumented internal function.
-
 ## -*- texinfo -*-
 ## @deftypefn {} {@var{yi} =} __splinen__ (@var{x}, @var{y}, @var{xi}, @var{extrapval}, @var{f})
 ## Undocumented internal function.
 ## @end deftypefn
 
-## FIXME: Allow arbitrary grids..
+## FIXME: Allow arbitrary grids.
 
 function yi = __splinen__ (x, y, xi, extrapval, f)
 
-  ## ND isvector function.
-  isvec = @(x) numel (x) == length (x);
-  if (! iscell (x) || length (x) < ndims (y) || any (! cellfun (isvec, x))
-      || ! iscell (xi) || length (xi) < ndims (y)
-      || any (! cellfun (isvec, xi)))
+  ## ND function to check whether any object in cell array is *not* a vector.
+  isnotvec = @(x) cellfun ("numel", x) != cellfun ("length", x);
+  if (! iscell (x) || length (x) < ndims (y) || any (isnotvec (x))
+      || ! iscell (xi) || length (xi) < ndims (y) || any (isnotvec (xi)))
     error ("__splinen__: %s: non-gridded data or dimensions inconsistent", f);
   endif
+
   yi = y;
   for i = length (x):-1:1
     yi = permute (spline (x{i}, yi, xi{i}(:)), [length(x),1:length(x)-1]);
   endfor
 
-  [xi{:}] = ndgrid (cellfun (@(x) x(:), xi, "uniformoutput", false){:});
   if (! isempty (extrapval))
-    idx = zeros (size (xi{1}));
+    [xi{:}] = ndgrid (cellfun (@(x) x(:), xi, "uniformoutput", false){:});
+    idx = false (size (xi{1}));
     for i = 1 : length (x)
       idx |= xi{i} < min (x{i}(:)) | xi{i} > max (x{i}(:));
     endfor