Mercurial > octave
comparison libinterp/corefcn/__lin_interpn__.cc @ 30688:d1fe2cb16d95
Fix __lin_interpn__ for complex input (bug #61907).
* libinterp/corefcn/__lin_interpn__.cc (lin_interpn<T, DT>,
lin_interpn<MT, DMT, DT>): Generalize implementation for complex-valued input.
* scripts/general/interpn.m: Remove work-around.
author | Christof Kaufmann <christofkaufmann.dev@gmail.com> |
---|---|
date | Wed, 26 Jan 2022 14:42:09 +0100 |
parents | 796f54d4ddbf |
children | e88a07dec498 |
comparison
equal
deleted
inserted
replaced
30687:97989152bfbe | 30688:d1fe2cb16d95 |
---|---|
33 | 33 |
34 #include "defun.h" | 34 #include "defun.h" |
35 #include "error.h" | 35 #include "error.h" |
36 #include "ovl.h" | 36 #include "ovl.h" |
37 | 37 |
38 #include <type_traits> | |
39 | |
38 OCTAVE_NAMESPACE_BEGIN | 40 OCTAVE_NAMESPACE_BEGIN |
39 | 41 |
40 // equivalent to isvector.m | 42 // equivalent to isvector.m |
41 | 43 |
42 template <typename T> | 44 template <typename T> |
125 } | 127 } |
126 } | 128 } |
127 | 129 |
128 // n-dimensional linear interpolation | 130 // n-dimensional linear interpolation |
129 | 131 |
130 template <typename T> | 132 template <typename T, typename DT> |
131 void | 133 void |
132 lin_interpn (int n, const octave_idx_type *size, const octave_idx_type *scale, | 134 lin_interpn (int n, const octave_idx_type *size, const octave_idx_type *scale, |
133 octave_idx_type Ni, T extrapval, const T **x, | 135 octave_idx_type Ni, DT extrapval, const T **x, |
134 const T *v, const T **y, T *vi) | 136 const DT *v, const T **y, DT *vi) |
135 { | 137 { |
136 bool out = false; | 138 bool out = false; |
137 int bit; | 139 int bit; |
138 | 140 |
139 OCTAVE_LOCAL_BUFFER (T, coef, 2*n); | 141 OCTAVE_LOCAL_BUFFER (T, coef, 2*n); |
183 } | 185 } |
184 } | 186 } |
185 } | 187 } |
186 } | 188 } |
187 | 189 |
188 template <typename T, typename M> | 190 template <typename MT, typename DMT, typename DT> |
189 octave_value | 191 octave_value |
190 lin_interpn (int n, M *X, const M V, M *Y) | 192 lin_interpn (int n, MT *X, const DMT V, MT *Y, DT extrapval) |
191 { | 193 { |
194 static_assert(std::is_same<DT, typename DMT::element_type>::value, | |
195 "Type DMT must be an ArrayType with elements of type DT."); | |
196 using T = typename MT::element_type; | |
197 | |
192 octave_value retval; | 198 octave_value retval; |
193 | 199 |
194 M Vi = M (Y[0].dims ()); | 200 DMT Vi = DMT (Y[0].dims ()); |
195 | 201 |
196 OCTAVE_LOCAL_BUFFER (const T *, y, n); | 202 OCTAVE_LOCAL_BUFFER (const T *, y, n); |
197 OCTAVE_LOCAL_BUFFER (octave_idx_type, size, n); | 203 OCTAVE_LOCAL_BUFFER (octave_idx_type, size, n); |
198 | 204 |
199 for (int i = 0; i < n; i++) | 205 for (int i = 0; i < n; i++) |
203 } | 209 } |
204 | 210 |
205 OCTAVE_LOCAL_BUFFER (const T *, x, n); | 211 OCTAVE_LOCAL_BUFFER (const T *, x, n); |
206 OCTAVE_LOCAL_BUFFER (octave_idx_type, scale, n); | 212 OCTAVE_LOCAL_BUFFER (octave_idx_type, scale, n); |
207 | 213 |
208 const T *v = V.data (); | 214 const DT *v = V.data (); |
209 T *vi = Vi.fortran_vec (); | 215 DT *vi = Vi.fortran_vec (); |
210 octave_idx_type Ni = Vi.numel (); | 216 octave_idx_type Ni = Vi.numel (); |
211 | |
212 T extrapval = octave_NA; | |
213 | 217 |
214 // offset in memory of each dimension | 218 // offset in memory of each dimension |
215 | 219 |
216 scale[0] = 1; | 220 scale[0] = 1; |
217 | 221 |
226 for (int i = 0; i < n; i++) | 230 for (int i = 0; i < n; i++) |
227 { | 231 { |
228 if (X[i].dims () != V.dims ()) | 232 if (X[i].dims () != V.dims ()) |
229 error ("interpn: incompatible size of argument number %d", i+1); | 233 error ("interpn: incompatible size of argument number %d", i+1); |
230 | 234 |
231 M tmp = M (dim_vector (size[i], 1)); | 235 MT tmp = MT (dim_vector (size[i], 1)); |
232 | 236 |
233 for (octave_idx_type j = 0; j < size[i]; j++) | 237 for (octave_idx_type j = 0; j < size[i]; j++) |
234 tmp(j) = X[i](scale[i]*j); | 238 tmp(j) = X[i](scale[i]*j); |
235 | 239 |
236 X[i] = tmp; | 240 X[i] = tmp; |
282 if (args(n).is_single_type ()) | 286 if (args(n).is_single_type ()) |
283 { | 287 { |
284 OCTAVE_LOCAL_BUFFER (FloatNDArray, X, n); | 288 OCTAVE_LOCAL_BUFFER (FloatNDArray, X, n); |
285 OCTAVE_LOCAL_BUFFER (FloatNDArray, Y, n); | 289 OCTAVE_LOCAL_BUFFER (FloatNDArray, Y, n); |
286 | 290 |
287 const FloatNDArray V = args(n).float_array_value (); | |
288 | |
289 for (int i = 0; i < n; i++) | 291 for (int i = 0; i < n; i++) |
290 { | 292 { |
291 X[i] = args(i).float_array_value (); | 293 X[i] = args(i).float_array_value (); |
292 Y[i] = args(n+i+1).float_array_value (); | 294 Y[i] = args(n+i+1).float_array_value (); |
293 | 295 |
294 if (Y[0].dims () != Y[i].dims ()) | 296 if (Y[0].dims () != Y[i].dims ()) |
295 error ("interpn: incompatible size of argument number %d", n+i+2); | 297 error ("interpn: incompatible size of argument number %d", n+i+2); |
296 } | 298 } |
297 | 299 |
298 retval = lin_interpn<float, FloatNDArray> (n, X, V, Y); | 300 if (args(n).iscomplex ()) |
301 { | |
302 const FloatComplexNDArray V = args(n).float_complex_array_value (); | |
303 FloatComplex extrapval (octave_NA, octave_NA); | |
304 retval = lin_interpn (n, X, V, Y, extrapval); | |
305 } | |
306 else | |
307 { | |
308 const FloatNDArray V = args(n).float_array_value (); | |
309 float extrapval = octave_NA; | |
310 retval = lin_interpn (n, X, V, Y, extrapval); | |
311 } | |
299 } | 312 } |
300 else | 313 else |
301 { | 314 { |
302 OCTAVE_LOCAL_BUFFER (NDArray, X, n); | 315 OCTAVE_LOCAL_BUFFER (NDArray, X, n); |
303 OCTAVE_LOCAL_BUFFER (NDArray, Y, n); | 316 OCTAVE_LOCAL_BUFFER (NDArray, Y, n); |
304 | 317 |
305 const NDArray V = args(n).array_value (); | |
306 | |
307 for (int i = 0; i < n; i++) | 318 for (int i = 0; i < n; i++) |
308 { | 319 { |
309 X[i] = args(i).array_value (); | 320 X[i] = args(i).array_value (); |
310 Y[i] = args(n+i+1).array_value (); | 321 Y[i] = args(n+i+1).array_value (); |
311 | 322 |
312 if (Y[0].dims () != Y[i].dims ()) | 323 if (Y[0].dims () != Y[i].dims ()) |
313 error ("interpn: incompatible size of argument number %d", n+i+2); | 324 error ("interpn: incompatible size of argument number %d", n+i+2); |
314 } | 325 } |
315 | 326 |
316 retval = lin_interpn<double, NDArray> (n, X, V, Y); | 327 if (args(n).iscomplex ()) |
328 { | |
329 const ComplexNDArray V = args(n).complex_array_value (); | |
330 Complex extrapval (octave_NA, octave_NA); | |
331 retval = lin_interpn (n, X, V, Y, extrapval); | |
332 } | |
333 else | |
334 { | |
335 const NDArray V = args(n).array_value (); | |
336 double extrapval = octave_NA; | |
337 retval = lin_interpn (n, X, V, Y, extrapval); | |
338 } | |
317 } | 339 } |
318 | 340 |
319 return retval; | 341 return retval; |
320 } | 342 } |
321 | 343 |
322 /* | 344 /* |
323 ## No test needed for internal helper function. | 345 ## Test that real and imaginary parts are interpolated the same way |
324 %!assert (1) | 346 ## and outer points are set to NA + 1i*NA |
347 %!test <*61907> | |
348 %! x1 = 1:3; | |
349 %! x2 = 1:4; | |
350 %! v = repmat(1:4, 3, 1) + 1i * repmat((1:3)', 1, 4); | |
351 %! [XI2, XI1] = meshgrid(1.5:3.5, 1.5:3.5); | |
352 %! vi_complex = __lin_interpn__ (x1, x2, v, XI1, XI2); | |
353 %! vi_real = __lin_interpn__ (x1, x2, real (v), XI1, XI2); | |
354 %! vi_imag = __lin_interpn__ (x1, x2, imag (v), XI1, XI2); | |
355 %! assert (real (vi_complex), vi_real); | |
356 %! assert (imag (vi_complex), vi_imag); | |
325 */ | 357 */ |
326 | 358 |
327 OCTAVE_NAMESPACE_END | 359 OCTAVE_NAMESPACE_END |