Mercurial > jwe > octave
changeset 29293:512591ccb174
interpn.m: Overhaul function.
* interpn.m: Use @code rather than @qcode for 'NA' value in documentation.
Move initial input validation checking number of arguments to head of file.
Add comments to help explain what each block of code is doing.
Call validatestring() with third argument so that it reports originating
function. Fix incorrect input validation when nargs <= 2.
Replace two calls to size_equal() with one for better performance.
Make error messages for input validation more specific.
Recode to eliminate unnecessary temporary variable "foobar".
Use false() rather than zeros() when creating logical indices to save 7/8
memory. Add input validation for unimplemented METHOD "pchip".
Remove unnecessary error message for unrecognized interpolation METHOD.
Add input validation BIST tests.
* __splinen__.m: Remove unnecessary comments. Correct spelling in
FIXME comment. Use cellfun with accelerated functions for a 2/3rds
speed improvement in determining whether non-vector arrays are present.
Move ndgrid() call on XI in to if block that is only executed when
EXTRAPVAL is present. Use false() rather than zeros() when creating logical
indices to save 7/8 memory.
author | Rik <rik@octave.org> |
---|---|
date | Tue, 12 Jan 2021 17:15:53 -0800 |
parents | 2059b340e924 |
children | 22215ea74b84 |
files | scripts/general/interpn.m scripts/general/private/__splinen__.m |
diffstat | 2 files changed, 57 insertions(+), 42 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/general/interpn.m Tue Jan 12 08:47:18 2021 -0800 +++ b/scripts/general/interpn.m Tue Jan 12 17:15:53 2021 -0800 @@ -80,20 +80,21 @@ ## must be specified as well. If @var{extrapval} is omitted and the ## @var{method} is @qcode{"spline"}, then the extrapolated values of the ## @qcode{"spline"} are used. Otherwise the default @var{extrapval} value for -## any other @var{method} is @qcode{"NA"}. +## any other @var{method} is @code{NA}. ## @seealso{interp1, interp2, interp3, spline, ndgrid} ## @end deftypefn function vi = interpn (varargin) + if (nargin < 1 || ! isnumeric (varargin{1})) + print_usage (); + endif + method = "linear"; extrapval = []; nargs = nargin; - if (nargin < 1 || ! isnumeric (varargin{1})) - print_usage (); - endif - + ## Find and validate EXTRAPVAL and/or METHOD inputs if (nargs > 1 && ischar (varargin{end-1})) if (! isnumeric (varargin{end}) || ! isscalar (varargin{end})) error ("interpn: EXTRAPVAL must be a numeric scalar"); @@ -110,18 +111,18 @@ warning ("interpn: ignoring unsupported '*' flag to METHOD"); method(1) = []; endif - method = validatestring (method, ... - {"nearest", "linear", "pchip", "cubic", "spline"}); + method = tolower (method); + method = validatestring (method, + {"nearest", "linear", "pchip", "cubic", "spline"}, + "interpn"); - if (nargs < 3) + if (nargs <= 2) + ## Calling form interpn (V, ...) v = varargin{1}; m = 1; if (nargs == 2) - if (ischar (varargin{2})) - method = varargin{2}; - elseif (isnumeric (m) && isscalar (m) && fix (m) == m) - m = varargin{2}; - else + m = varargin{2}; + if (! (isnumeric (m) && isscalar (m) && m == fix (m))) print_usage (); endif endif @@ -136,6 +137,7 @@ y{1} = y{1}.'; [y{:}] = ndgrid (y{:}); elseif (! isvector (varargin{1}) && nargs == (ndims (varargin{1}) + 1)) + ## Calling form interpn (V, Y1, Y2, ...) v = varargin{1}; sz = size (v); nd = ndims (v); @@ -146,6 +148,7 @@ endfor elseif (rem (nargs, 2) == 1 && nargs == (2 * ndims (varargin{ceil (nargs / 2)})) + 1) + ## Calling form interpn (X1, X2, ..., V, Y1, Y2, ...) nv = ceil (nargs / 2); v = varargin{nv}; sz = size (v); @@ -157,26 +160,20 @@ endif if (any (! cellfun ("isvector", x))) - for i = 2 : nd - if (! size_equal (x{1}, x{i}) || ! size_equal (x{i}, v)) - error ("interpn: dimensional mismatch"); + for i = 1 : nd + if (! size_equal (x{i}, v)) + error ("interpn: incorrect dimensions for input X%d", i); endif idx(1 : nd) = {1}; idx(i) = ":"; x{i} = x{i}(idx{:})(:); endfor - idx(1 : nd) = {1}; - idx(1) = ":"; - x{1} = x{1}(idx{:})(:); endif - method = tolower (method); - all_vectors = all (cellfun ("isvector", y)); same_size = size_equal (y{:}); if (all_vectors && ! same_size) - [foobar(1:numel(y)).y] = ndgrid (y{:}); - y = {foobar.y}; + [y{:}] = ndgrid (y{:}); endif if (strcmp (method, "linear")) @@ -186,19 +183,25 @@ endif vi(isna (vi)) = extrapval; elseif (strcmp (method, "nearest")) + ## FIXME: This seems overly complicated. Is there a way to simplify + ## all the code after the call to lookup (which should be fast)? + ## Could Qhull be used for quick nearest neighbor calculation? yshape = size (y{1}); yidx = cell (1, nd); + ## Find rough nearest index using lookup function [O(log2 (N)]. for i = 1 : nd y{i} = y{i}(:); yidx{i} = lookup (x{i}, y{i}, "lr"); endfor + ## Single comparison to next largest index to see which is closer. idx = cell (1,nd); for i = 1 : nd idx{i} = yidx{i} ... + (y{i} - x{i}(yidx{i})(:) >= x{i}(yidx{i} + 1)(:) - y{i}); endfor vi = v(sub2ind (sz, idx{:})); - idx = zeros (prod (yshape), 1); + ## Apply EXTRAPVAL to points outside original volume. + idx = false (prod (yshape), 1); for i = 1 : nd idx |= y{i} < min (x{i}(:)) | y{i} > max (x{i}(:)); endfor @@ -209,17 +212,15 @@ vi = reshape (vi, yshape); elseif (strcmp (method, "spline")) if (any (! cellfun ("isvector", y))) - for i = 2 : nd - if (! size_equal (y{1}, y{i})) - error ("interpn: dimensional mismatch"); + ysz = size (y{1}); + for i = 1 : nd + if (any (size (y{i}) != ysz)) + error ("interpn: incorrect dimensions for input Y%d", i); endif idx(1 : nd) = {1}; idx(i) = ":"; y{i} = y{i}(idx{:}); endfor - idx(1 : nd) = {1}; - idx(1) = ":"; - y{1} = y{1}(idx{:}); endif vi = __splinen__ (x, v, y, extrapval, "interpn"); @@ -235,10 +236,10 @@ vi = vi(cellfun (@(x) sub2ind (size (vi), x{:}), idx)); vi = reshape (vi, size (y{1})); endif + elseif (strcmp (method, "pchip")) + error ("interpn: pchip interpolation not yet implemented"); elseif (strcmp (method, "cubic")) error ("interpn: cubic interpolation not yet implemented"); - else - error ("interpn: unrecognized interpolation METHOD"); endif endfunction @@ -353,4 +354,20 @@ %! assert (interpn (z, "spline"), zout, tol); ## Test input validation +%!error <Invalid call> interpn () +%!error <Invalid call> interpn ("foobar") +%!error <EXTRAPVAL must be a numeric scalar> interpn (1, "linear", {1}) +%!error <EXTRAPVAL must be a numeric scalar> interpn (1, "linear", [1, 2]) %!warning <ignoring unsupported '\*' flag> interpn (rand (3,3), 1, "*linear"); +%!error <'foobar' does not match any of> interpn (1, "foobar") +%!error <wrong number or incorrectly formatted input arguments> +%! interpn (1, 2, 3, 4); +%!error <incorrect dimensions for input X1> +%! interpn ([1,2], ones (2,2), magic (3), [1,2], [1,2]) +%!error <incorrect dimensions for input X2> +%! interpn (ones (3,3), ones (2,2), magic (3), [1,2], [1,2]) +%!error <incorrect dimensions for input Y2> +%! interpn ([1,2], [1,2], magic (3), [1,2], ones (2,2), "spline") +%!error <pchip interpolation not yet implemented> interpn ([1,2], "pchip") +%!error <cubic interpolation not yet implemented> interpn ([1,2], "cubic") +
--- a/scripts/general/private/__splinen__.m Tue Jan 12 08:47:18 2021 -0800 +++ b/scripts/general/private/__splinen__.m Tue Jan 12 17:15:53 2021 -0800 @@ -23,32 +23,30 @@ ## ######################################################################## -## Undocumented internal function. - ## -*- texinfo -*- ## @deftypefn {} {@var{yi} =} __splinen__ (@var{x}, @var{y}, @var{xi}, @var{extrapval}, @var{f}) ## Undocumented internal function. ## @end deftypefn -## FIXME: Allow arbitrary grids.. +## FIXME: Allow arbitrary grids. function yi = __splinen__ (x, y, xi, extrapval, f) - ## ND isvector function. - isvec = @(x) numel (x) == length (x); - if (! iscell (x) || length (x) < ndims (y) || any (! cellfun (isvec, x)) - || ! iscell (xi) || length (xi) < ndims (y) - || any (! cellfun (isvec, xi))) + ## ND function to check whether any object in cell array is *not* a vector. + isnotvec = @(x) cellfun ("numel", x) != cellfun ("length", x); + if (! iscell (x) || length (x) < ndims (y) || any (isnotvec (x)) + || ! iscell (xi) || length (xi) < ndims (y) || any (isnotvec (xi))) error ("__splinen__: %s: non-gridded data or dimensions inconsistent", f); endif + yi = y; for i = length (x):-1:1 yi = permute (spline (x{i}, yi, xi{i}(:)), [length(x),1:length(x)-1]); endfor - [xi{:}] = ndgrid (cellfun (@(x) x(:), xi, "uniformoutput", false){:}); if (! isempty (extrapval)) - idx = zeros (size (xi{1})); + [xi{:}] = ndgrid (cellfun (@(x) x(:), xi, "uniformoutput", false){:}); + idx = false (size (xi{1})); for i = 1 : length (x) idx |= xi{i} < min (x{i}(:)) | xi{i} > max (x{i}(:)); endfor