changeset 30688:d1fe2cb16d95

Fix __lin_interpn__ for complex input (bug #61907). * libinterp/corefcn/__lin_interp__.cc (lin_interpn<T, DT>, lin_interpn<MT, DMT, DT>): Generalize implementation for complex valued input. * scripts/general/interpn.m: Remove work-around.
author Christof Kaufmann <christofkaufmann.dev@gmail.com>
date Wed, 26 Jan 2022 14:42:09 +0100
parents 97989152bfbe
children 4b367bf5eb16
files libinterp/corefcn/__lin_interpn__.cc scripts/general/interpn.m
diffstat 2 files changed, 51 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/__lin_interpn__.cc	Sat Jan 29 16:24:40 2022 -0800
+++ b/libinterp/corefcn/__lin_interpn__.cc	Wed Jan 26 14:42:09 2022 +0100
@@ -35,6 +35,8 @@
 #include "error.h"
 #include "ovl.h"
 
+#include <type_traits>
+
 OCTAVE_NAMESPACE_BEGIN
 
 // equivalent to isvector.m
@@ -127,11 +129,11 @@
 
 // n-dimensional linear interpolation
 
-template <typename T>
+template <typename T, typename DT>
 void
 lin_interpn (int n, const octave_idx_type *size, const octave_idx_type *scale,
-             octave_idx_type Ni, T extrapval, const T **x,
-             const T *v, const T **y, T *vi)
+             octave_idx_type Ni, DT extrapval, const T **x,
+             const DT *v, const T **y, DT *vi)
 {
   bool out = false;
   int bit;
@@ -185,13 +187,17 @@
     }
 }
 
-template <typename T, typename M>
+template <typename MT, typename DMT, typename DT>
 octave_value
-lin_interpn (int n, M *X, const M V, M *Y)
+lin_interpn (int n, MT *X, const DMT V, MT *Y, DT extrapval)
 {
+  static_assert(std::is_same<DT, typename DMT::element_type>::value,
+                "Type DMT must be an ArrayType with elements of type DT.");
+  using T = typename MT::element_type;
+
   octave_value retval;
 
-  M Vi = M (Y[0].dims ());
+  DMT Vi = DMT (Y[0].dims ());
 
   OCTAVE_LOCAL_BUFFER (const T *, y, n);
   OCTAVE_LOCAL_BUFFER (octave_idx_type, size, n);
@@ -205,12 +211,10 @@
   OCTAVE_LOCAL_BUFFER (const T *, x, n);
   OCTAVE_LOCAL_BUFFER (octave_idx_type, scale, n);
 
-  const T *v = V.data ();
-  T *vi = Vi.fortran_vec ();
+  const DT *v = V.data ();
+  DT *vi = Vi.fortran_vec ();
   octave_idx_type Ni = Vi.numel ();
 
-  T extrapval = octave_NA;
-
   // offset in memory of each dimension
 
   scale[0] = 1;
@@ -228,7 +232,7 @@
           if (X[i].dims () != V.dims ())
             error ("interpn: incompatible size of argument number %d", i+1);
 
-          M tmp = M (dim_vector (size[i], 1));
+          MT tmp = MT (dim_vector (size[i], 1));
 
           for (octave_idx_type j = 0; j < size[i]; j++)
             tmp(j) = X[i](scale[i]*j);
@@ -284,8 +288,6 @@
       OCTAVE_LOCAL_BUFFER (FloatNDArray, X, n);
       OCTAVE_LOCAL_BUFFER (FloatNDArray, Y, n);
 
-      const FloatNDArray V = args(n).float_array_value ();
-
       for (int i = 0; i < n; i++)
         {
           X[i] = args(i).float_array_value ();
@@ -295,15 +297,24 @@
             error ("interpn: incompatible size of argument number %d", n+i+2);
         }
 
-      retval = lin_interpn<float, FloatNDArray> (n, X, V, Y);
+      if (args(n).iscomplex ())
+        {
+          const FloatComplexNDArray V = args(n).float_complex_array_value ();
+          FloatComplex extrapval (octave_NA, octave_NA);
+          retval = lin_interpn (n, X, V, Y, extrapval);
+        }
+      else
+        {
+          const FloatNDArray V = args(n).float_array_value ();
+          float extrapval = octave_NA;
+          retval = lin_interpn (n, X, V, Y, extrapval);
+        }
     }
   else
     {
       OCTAVE_LOCAL_BUFFER (NDArray, X, n);
       OCTAVE_LOCAL_BUFFER (NDArray, Y, n);
 
-      const NDArray V = args(n).array_value ();
-
       for (int i = 0; i < n; i++)
         {
           X[i] = args(i).array_value ();
@@ -313,15 +324,36 @@
             error ("interpn: incompatible size of argument number %d", n+i+2);
         }
 
-      retval = lin_interpn<double, NDArray> (n, X, V, Y);
+      if (args(n).iscomplex ())
+        {
+          const ComplexNDArray V = args(n).complex_array_value ();
+          Complex extrapval (octave_NA, octave_NA);
+          retval = lin_interpn (n, X, V, Y, extrapval);
+        }
+      else
+        {
+          const NDArray V = args(n).array_value ();
+          double extrapval = octave_NA;
+          retval = lin_interpn (n, X, V, Y, extrapval);
+        }
     }
 
   return retval;
 }
 
 /*
-## No test needed for internal helper function.
-%!assert (1)
+## Test that real and imaginary parts are interpolated the same way
+## and outer points are set to NA + 1i*NA
+%!test <*61907>
+%! x1 = 1:3;
+%! x2 = 1:4;
+%! v = repmat(1:4, 3, 1) + 1i * repmat((1:3)', 1, 4);
+%! [XI2, XI1] = meshgrid(1.5:3.5, 1.5:3.5);
+%! vi_complex = __lin_interpn__ (x1, x2, v, XI1, XI2);
+%! vi_real = __lin_interpn__ (x1, x2, real (v), XI1, XI2);
+%! vi_imag = __lin_interpn__ (x1, x2, imag (v), XI1, XI2);
+%! assert (real (vi_complex), vi_real);
+%! assert (imag (vi_complex), vi_imag);
 */
 
 OCTAVE_NAMESPACE_END
--- a/scripts/general/interpn.m	Sat Jan 29 16:24:40 2022 -0800
+++ b/scripts/general/interpn.m	Wed Jan 26 14:42:09 2022 +0100
@@ -185,11 +185,6 @@
 
   if (strcmp (method, "linear"))
     vi = __lin_interpn__ (x{:}, v, y{:});
-    if (iscomplex (v))
-      ## __lin_interpn__ ignores imaginary part. Do it again for imag part.
-      ## FIXME: Adapt __lin_interpn__ to correctly handle complex input.
-      vi += 1i * __lin_interpn__ (x{:}, imag (v), y{:});
-    endif
     vi(isna (vi)) = extrapval;
   elseif (strcmp (method, "nearest"))
     ## FIXME: This seems overly complicated.  Is there a way to simplify