Mercurial > octave
changeset 30688:d1fe2cb16d95
Fix __lin_interpn__ for complex input (bug #61907).
* libinterp/corefcn/__lin_interp__.cc (lin_interpn<T, DT>,
lin_interpn<MT, DMT, DT>): Generalize implementation for complex valued input.
* scripts/general/interpn.m: Remove work-around.
author | Christof Kaufmann <christofkaufmann.dev@gmail.com> |
---|---|
date | Wed, 26 Jan 2022 14:42:09 +0100 |
parents | 97989152bfbe |
children | 4b367bf5eb16 |
files | libinterp/corefcn/__lin_interpn__.cc scripts/general/interpn.m |
diffstat | 2 files changed, 51 insertions(+), 24 deletions(-) [+] |
line wrap: on
line diff
--- a/libinterp/corefcn/__lin_interpn__.cc Sat Jan 29 16:24:40 2022 -0800 +++ b/libinterp/corefcn/__lin_interpn__.cc Wed Jan 26 14:42:09 2022 +0100 @@ -35,6 +35,8 @@ #include "error.h" #include "ovl.h" +#include <type_traits> + OCTAVE_NAMESPACE_BEGIN // equivalent to isvector.m @@ -127,11 +129,11 @@ // n-dimensional linear interpolation -template <typename T> +template <typename T, typename DT> void lin_interpn (int n, const octave_idx_type *size, const octave_idx_type *scale, - octave_idx_type Ni, T extrapval, const T **x, - const T *v, const T **y, T *vi) + octave_idx_type Ni, DT extrapval, const T **x, + const DT *v, const T **y, DT *vi) { bool out = false; int bit; @@ -185,13 +187,17 @@ } } -template <typename T, typename M> +template <typename MT, typename DMT, typename DT> octave_value -lin_interpn (int n, M *X, const M V, M *Y) +lin_interpn (int n, MT *X, const DMT V, MT *Y, DT extrapval) { + static_assert(std::is_same<DT, typename DMT::element_type>::value, + "Type DMT must be an ArrayType with elements of type DT."); + using T = typename MT::element_type; + octave_value retval; - M Vi = M (Y[0].dims ()); + DMT Vi = DMT (Y[0].dims ()); OCTAVE_LOCAL_BUFFER (const T *, y, n); OCTAVE_LOCAL_BUFFER (octave_idx_type, size, n); @@ -205,12 +211,10 @@ OCTAVE_LOCAL_BUFFER (const T *, x, n); OCTAVE_LOCAL_BUFFER (octave_idx_type, scale, n); - const T *v = V.data (); - T *vi = Vi.fortran_vec (); + const DT *v = V.data (); + DT *vi = Vi.fortran_vec (); octave_idx_type Ni = Vi.numel (); - T extrapval = octave_NA; - // offset in memory of each dimension scale[0] = 1; @@ -228,7 +232,7 @@ if (X[i].dims () != V.dims ()) error ("interpn: incompatible size of argument number %d", i+1); - M tmp = M (dim_vector (size[i], 1)); + MT tmp = MT (dim_vector (size[i], 1)); for (octave_idx_type j = 0; j < size[i]; j++) tmp(j) = X[i](scale[i]*j); @@ -284,8 +288,6 @@ OCTAVE_LOCAL_BUFFER (FloatNDArray, X, n); OCTAVE_LOCAL_BUFFER (FloatNDArray, Y, n); - const FloatNDArray V = args(n).float_array_value (); - for (int i = 0; i < n; i++) { X[i] = args(i).float_array_value (); @@ -295,15 +297,24 @@ error ("interpn: incompatible size of argument number %d", n+i+2); } - retval = lin_interpn<float, FloatNDArray> (n, X, V, Y); + if (args(n).iscomplex ()) + { + const FloatComplexNDArray V = args(n).float_complex_array_value (); + FloatComplex extrapval (octave_NA, octave_NA); + retval = lin_interpn (n, X, V, Y, extrapval); + } + else + { + const FloatNDArray V = args(n).float_array_value (); + float extrapval = octave_NA; + retval = lin_interpn (n, X, V, Y, extrapval); + } } else { OCTAVE_LOCAL_BUFFER (NDArray, X, n); OCTAVE_LOCAL_BUFFER (NDArray, Y, n); - const NDArray V = args(n).array_value (); - for (int i = 0; i < n; i++) { X[i] = args(i).array_value (); @@ -313,15 +324,36 @@ error ("interpn: incompatible size of argument number %d", n+i+2); } - retval = lin_interpn<double, NDArray> (n, X, V, Y); + if (args(n).iscomplex ()) + { + const ComplexNDArray V = args(n).complex_array_value (); + Complex extrapval (octave_NA, octave_NA); + retval = lin_interpn (n, X, V, Y, extrapval); + } + else + { + const NDArray V = args(n).array_value (); + double extrapval = octave_NA; + retval = lin_interpn (n, X, V, Y, extrapval); + } } return retval; } /* -## No test needed for internal helper function. -%!assert (1) +## Test that real and imaginary parts are interpolated the same way +## and outer points are set to NA + 1i*NA +%!test <*61907> +%! x1 = 1:3; +%! x2 = 1:4; +%! v = repmat(1:4, 3, 1) + 1i * repmat((1:3)', 1, 4); +%! [XI2, XI1] = meshgrid(1.5:3.5, 1.5:3.5); +%! vi_complex = __lin_interpn__ (x1, x2, v, XI1, XI2); +%! vi_real = __lin_interpn__ (x1, x2, real (v), XI1, XI2); +%! vi_imag = __lin_interpn__ (x1, x2, imag (v), XI1, XI2); +%! assert (real (vi_complex), vi_real); +%! assert (imag (vi_complex), vi_imag); */ OCTAVE_NAMESPACE_END
--- a/scripts/general/interpn.m Sat Jan 29 16:24:40 2022 -0800 +++ b/scripts/general/interpn.m Wed Jan 26 14:42:09 2022 +0100 @@ -185,11 +185,6 @@ if (strcmp (method, "linear")) vi = __lin_interpn__ (x{:}, v, y{:}); - if (iscomplex (v)) - ## __lin_interpn__ ignores imaginary part. Do it again for imag part. - ## FIXME: Adapt __lin_interpn__ to correctly handle complex input. - vi += 1i * __lin_interpn__ (x{:}, imag (v), y{:}); - endif vi(isna (vi)) = extrapval; elseif (strcmp (method, "nearest")) ## FIXME: This seems overly complicated. Is there a way to simplify