comparison libinterp/corefcn/__lin_interpn__.cc @ 30688:d1fe2cb16d95

Fix __lin_interpn__ for complex input (bug #61907). * libinterp/corefcn/__lin_interpn__.cc (lin_interpn<T, DT>, lin_interpn<MT, DMT, DT>): Generalize implementation for complex-valued input. * scripts/general/interpn.m: Remove work-around.
author Christof Kaufmann <christofkaufmann.dev@gmail.com>
date Wed, 26 Jan 2022 14:42:09 +0100
parents 796f54d4ddbf
children e88a07dec498
comparison
equal deleted inserted replaced
30687:97989152bfbe 30688:d1fe2cb16d95
33 33
34 #include "defun.h" 34 #include "defun.h"
35 #include "error.h" 35 #include "error.h"
36 #include "ovl.h" 36 #include "ovl.h"
37 37
38 #include <type_traits>
39
38 OCTAVE_NAMESPACE_BEGIN 40 OCTAVE_NAMESPACE_BEGIN
39 41
40 // equivalent to isvector.m 42 // equivalent to isvector.m
41 43
42 template <typename T> 44 template <typename T>
125 } 127 }
126 } 128 }
127 129
128 // n-dimensional linear interpolation 130 // n-dimensional linear interpolation
129 131
130 template <typename T> 132 template <typename T, typename DT>
131 void 133 void
132 lin_interpn (int n, const octave_idx_type *size, const octave_idx_type *scale, 134 lin_interpn (int n, const octave_idx_type *size, const octave_idx_type *scale,
133 octave_idx_type Ni, T extrapval, const T **x, 135 octave_idx_type Ni, DT extrapval, const T **x,
134 const T *v, const T **y, T *vi) 136 const DT *v, const T **y, DT *vi)
135 { 137 {
136 bool out = false; 138 bool out = false;
137 int bit; 139 int bit;
138 140
139 OCTAVE_LOCAL_BUFFER (T, coef, 2*n); 141 OCTAVE_LOCAL_BUFFER (T, coef, 2*n);
183 } 185 }
184 } 186 }
185 } 187 }
186 } 188 }
187 189
188 template <typename T, typename M> 190 template <typename MT, typename DMT, typename DT>
189 octave_value 191 octave_value
190 lin_interpn (int n, M *X, const M V, M *Y) 192 lin_interpn (int n, MT *X, const DMT V, MT *Y, DT extrapval)
191 { 193 {
194 static_assert(std::is_same<DT, typename DMT::element_type>::value,
195 "Type DMT must be an ArrayType with elements of type DT.");
196 using T = typename MT::element_type;
197
192 octave_value retval; 198 octave_value retval;
193 199
194 M Vi = M (Y[0].dims ()); 200 DMT Vi = DMT (Y[0].dims ());
195 201
196 OCTAVE_LOCAL_BUFFER (const T *, y, n); 202 OCTAVE_LOCAL_BUFFER (const T *, y, n);
197 OCTAVE_LOCAL_BUFFER (octave_idx_type, size, n); 203 OCTAVE_LOCAL_BUFFER (octave_idx_type, size, n);
198 204
199 for (int i = 0; i < n; i++) 205 for (int i = 0; i < n; i++)
203 } 209 }
204 210
205 OCTAVE_LOCAL_BUFFER (const T *, x, n); 211 OCTAVE_LOCAL_BUFFER (const T *, x, n);
206 OCTAVE_LOCAL_BUFFER (octave_idx_type, scale, n); 212 OCTAVE_LOCAL_BUFFER (octave_idx_type, scale, n);
207 213
208 const T *v = V.data (); 214 const DT *v = V.data ();
209 T *vi = Vi.fortran_vec (); 215 DT *vi = Vi.fortran_vec ();
210 octave_idx_type Ni = Vi.numel (); 216 octave_idx_type Ni = Vi.numel ();
211
212 T extrapval = octave_NA;
213 217
214 // offset in memory of each dimension 218 // offset in memory of each dimension
215 219
216 scale[0] = 1; 220 scale[0] = 1;
217 221
226 for (int i = 0; i < n; i++) 230 for (int i = 0; i < n; i++)
227 { 231 {
228 if (X[i].dims () != V.dims ()) 232 if (X[i].dims () != V.dims ())
229 error ("interpn: incompatible size of argument number %d", i+1); 233 error ("interpn: incompatible size of argument number %d", i+1);
230 234
231 M tmp = M (dim_vector (size[i], 1)); 235 MT tmp = MT (dim_vector (size[i], 1));
232 236
233 for (octave_idx_type j = 0; j < size[i]; j++) 237 for (octave_idx_type j = 0; j < size[i]; j++)
234 tmp(j) = X[i](scale[i]*j); 238 tmp(j) = X[i](scale[i]*j);
235 239
236 X[i] = tmp; 240 X[i] = tmp;
282 if (args(n).is_single_type ()) 286 if (args(n).is_single_type ())
283 { 287 {
284 OCTAVE_LOCAL_BUFFER (FloatNDArray, X, n); 288 OCTAVE_LOCAL_BUFFER (FloatNDArray, X, n);
285 OCTAVE_LOCAL_BUFFER (FloatNDArray, Y, n); 289 OCTAVE_LOCAL_BUFFER (FloatNDArray, Y, n);
286 290
287 const FloatNDArray V = args(n).float_array_value ();
288
289 for (int i = 0; i < n; i++) 291 for (int i = 0; i < n; i++)
290 { 292 {
291 X[i] = args(i).float_array_value (); 293 X[i] = args(i).float_array_value ();
292 Y[i] = args(n+i+1).float_array_value (); 294 Y[i] = args(n+i+1).float_array_value ();
293 295
294 if (Y[0].dims () != Y[i].dims ()) 296 if (Y[0].dims () != Y[i].dims ())
295 error ("interpn: incompatible size of argument number %d", n+i+2); 297 error ("interpn: incompatible size of argument number %d", n+i+2);
296 } 298 }
297 299
298 retval = lin_interpn<float, FloatNDArray> (n, X, V, Y); 300 if (args(n).iscomplex ())
301 {
302 const FloatComplexNDArray V = args(n).float_complex_array_value ();
303 FloatComplex extrapval (octave_NA, octave_NA);
304 retval = lin_interpn (n, X, V, Y, extrapval);
305 }
306 else
307 {
308 const FloatNDArray V = args(n).float_array_value ();
309 float extrapval = octave_NA;
310 retval = lin_interpn (n, X, V, Y, extrapval);
311 }
299 } 312 }
300 else 313 else
301 { 314 {
302 OCTAVE_LOCAL_BUFFER (NDArray, X, n); 315 OCTAVE_LOCAL_BUFFER (NDArray, X, n);
303 OCTAVE_LOCAL_BUFFER (NDArray, Y, n); 316 OCTAVE_LOCAL_BUFFER (NDArray, Y, n);
304 317
305 const NDArray V = args(n).array_value ();
306
307 for (int i = 0; i < n; i++) 318 for (int i = 0; i < n; i++)
308 { 319 {
309 X[i] = args(i).array_value (); 320 X[i] = args(i).array_value ();
310 Y[i] = args(n+i+1).array_value (); 321 Y[i] = args(n+i+1).array_value ();
311 322
312 if (Y[0].dims () != Y[i].dims ()) 323 if (Y[0].dims () != Y[i].dims ())
313 error ("interpn: incompatible size of argument number %d", n+i+2); 324 error ("interpn: incompatible size of argument number %d", n+i+2);
314 } 325 }
315 326
316 retval = lin_interpn<double, NDArray> (n, X, V, Y); 327 if (args(n).iscomplex ())
328 {
329 const ComplexNDArray V = args(n).complex_array_value ();
330 Complex extrapval (octave_NA, octave_NA);
331 retval = lin_interpn (n, X, V, Y, extrapval);
332 }
333 else
334 {
335 const NDArray V = args(n).array_value ();
336 double extrapval = octave_NA;
337 retval = lin_interpn (n, X, V, Y, extrapval);
338 }
317 } 339 }
318 340
319 return retval; 341 return retval;
320 } 342 }
321 343
322 /* 344 /*
323 ## No test needed for internal helper function. 345 ## Test that real and imaginary parts are interpolated the same way
324 %!assert (1) 346 ## and outer points are set to NA + 1i*NA
347 %!test <*61907>
348 %! x1 = 1:3;
349 %! x2 = 1:4;
350 %! v = repmat(1:4, 3, 1) + 1i * repmat((1:3)', 1, 4);
351 %! [XI2, XI1] = meshgrid(1.5:3.5, 1.5:3.5);
352 %! vi_complex = __lin_interpn__ (x1, x2, v, XI1, XI2);
353 %! vi_real = __lin_interpn__ (x1, x2, real (v), XI1, XI2);
354 %! vi_imag = __lin_interpn__ (x1, x2, imag (v), XI1, XI2);
355 %! assert (real (vi_complex), vi_real);
356 %! assert (imag (vi_complex), vi_imag);
325 */ 357 */
326 358
327 OCTAVE_NAMESPACE_END 359 OCTAVE_NAMESPACE_END