Mercurial > octave-dspies
comparison libinterp/corefcn/find.cc @ 19008:80ca3b05d77c draft
New "dispatch" selects template argument from octave-value (Bug #42424, 42425)
* find.cc (Ffind): This method now calls dispatch() rather than attempting to
handle all matrix types on its own
(find_templated): Changed from a function template to a functor so that it
can be passed as a template template argument to dispatch().
(find_info): New struct that holds the other arguments to find
(n_to_find, dir, nargout).
Added unit tests for bugs 42424 and 42425
* (new file) dispatch.h (dispatch): A function that dispatches a call to the
right template instantiation based on the type of an octave_value argument.
author | David Spies <dnspies@gmail.com> |
---|---|
date | Sat, 21 Jun 2014 13:13:05 -0600 |
parents | 2e0613dadfee |
children |
comparison
equal
deleted
inserted
replaced
19006:2e0613dadfee | 19008:80ca3b05d77c |
---|---|
23 | 23 |
24 #ifdef HAVE_CONFIG_H | 24 #ifdef HAVE_CONFIG_H |
25 #include <config.h> | 25 #include <config.h> |
26 #endif | 26 #endif |
27 | 27 |
28 #include "dispatch.h" | |
28 #include "find.h" | 29 #include "find.h" |
29 | 30 |
30 #include "defun.h" | 31 #include "defun.h" |
31 #include "error.h" | 32 #include "error.h" |
32 #include "gripes.h" | 33 #include "gripes.h" |
33 #include "oct-obj.h" | 34 #include "oct-obj.h" |
34 | 35 |
35 namespace find | 36 namespace find |
36 { | 37 { |
38 // The find function should be seen as the canonical example demonstrating | |
39 // how to properly call dispatch.h | |
40 // It should always behave properly for all matrix types. | |
41 // | |
37 // ffind_result is a generic type used for storing the result of | 42 // ffind_result is a generic type used for storing the result of |
38 // a find operation. The way in which this result is stored will | 43 // a find operation. The way in which this result is stored will |
39 // vary based on whether the number of requested return values | 44 // vary based on whether the number of requested return values |
40 // is 1, 2, or 3. | 45 // is 1, 2, or 3. |
41 // Each instantiation of ffind_result must support a couple different | 46 // Each instantiation of ffind_result must support a couple different |
179 panic_impossible (); // Checked by *** in Ffind | 184 panic_impossible (); // Checked by *** in Ffind |
180 } | 185 } |
181 return octave_value_list (); | 186 return octave_value_list (); |
182 } | 187 } |
183 | 188 |
189 struct find_info | |
190 { | |
191 octave_idx_type n_to_find; | |
192 direction dir; | |
193 int nargout; | |
194 }; | |
195 | |
196 // This functor will be called by dispatch.h with the proper type M. | |
197 // This avoids having to explicitly list the different types find can | |
198 // handle and instead delegates that duty to the generic "dispatch" | |
199 // function. | |
184 template<typename M> | 200 template<typename M> |
185 octave_value_list | 201 struct find_templated |
186 find_templated (const M& v, int nargout, octave_idx_type n_to_find, | 202 { |
187 direction dir) | 203 octave_value_list |
188 { | 204 operator() (const M& v, const find_info& inf) |
189 return nargout_to_template (v, nargout, n_to_find, dir); | 205 { |
190 } | 206 return nargout_to_template (v, inf.nargout, inf.n_to_find, inf.dir); |
207 } | |
208 }; | |
191 | 209 |
192 } | 210 } |
193 | 211 |
194 DEFUN (find, args, nargout, | 212 DEFUN (find, args, nargout, |
195 "-*- texinfo -*-\n\ | 213 "-*- texinfo -*-\n\ |
253 @end group\n\ | 271 @end group\n\ |
254 @end example\n\ | 272 @end example\n\ |
255 @seealso{nonzeros}\n\ | 273 @seealso{nonzeros}\n\ |
256 @end deftypefn") | 274 @end deftypefn") |
257 { | 275 { |
276 find::find_info inf; | |
277 | |
278 if(nargout < 1) | |
279 nargout = 1; | |
280 else if(nargout > 3) | |
281 nargout = 3; | |
282 | |
283 inf.nargout = nargout; | |
284 | |
258 octave_value_list retval; | 285 octave_value_list retval; |
259 | 286 |
260 int nargin = args.length (); | 287 int nargin = args.length (); |
261 | 288 |
262 if (nargin > 3 || nargin < 1) | 289 if (nargin > 3 || nargin < 1) |
270 nargout = 1; | 297 nargout = 1; |
271 else if (nargout > 3) | 298 else if (nargout > 3) |
272 nargout = 3; | 299 nargout = 3; |
273 | 300 |
274 // Setup the default options. | 301 // Setup the default options. |
275 octave_idx_type n_to_find = -1; | 302 inf.n_to_find = -1; |
276 if (nargin > 1) | 303 if (nargin > 1) |
277 { | 304 { |
278 double val = args(1).scalar_value (); | 305 double val = args(1).scalar_value (); |
279 | 306 |
280 if (error_state || (val < 0 || (! xisinf (val) && val != xround (val)))) | 307 if (error_state || (val < 0 || (! xisinf (val) && val != xround (val)))) |
281 { | 308 { |
282 error ("find: N must be a non-negative integer"); | 309 error ("find: N must be a non-negative integer"); |
283 return retval; | 310 return retval; |
284 } | 311 } |
285 else if (! xisinf (val)) | 312 else if (! xisinf (val)) |
286 n_to_find = val; | 313 inf.n_to_find = val; |
287 } | 314 } |
288 | 315 |
289 // Direction to do the searching. | 316 // Direction to do the searching. |
290 direction dir = FORWARD; | 317 inf.dir = FORWARD; |
291 if (nargin > 2) | 318 if (nargin > 2) |
292 { | 319 { |
293 std::string s_arg = args(2).string_value (); | 320 std::string s_arg = args(2).string_value (); |
294 | 321 |
295 if (error_state) | 322 if (error_state) |
296 { | 323 { |
297 error ("find: DIRECTION must be \"first\" or \"last\""); | 324 error ("find: DIRECTION must be \"first\" or \"last\""); |
298 return retval; | 325 return retval; |
299 } | 326 } |
300 if (s_arg == "first") | 327 if (s_arg == "first") |
301 dir = FORWARD; | 328 inf.dir = FORWARD; |
302 else if (s_arg == "last") | 329 else if (s_arg == "last") |
303 dir = BACKWARD; | 330 inf.dir = BACKWARD; |
304 else | 331 else |
305 { | 332 { |
306 error ("find: DIRECTION must be \"first\" or \"last\""); | 333 error ("find: DIRECTION must be \"first\" or \"last\""); |
307 return retval; | 334 return retval; |
308 } | 335 } |
309 } | 336 } |
310 | 337 |
311 octave_value arg = args(0); | 338 const octave_value& arg = args(0); |
312 | 339 |
313 if (arg.is_bool_type ()) | 340 //For this special case, it's unnecessary to call dispatch because |
314 { | 341 //we already know the types of everything |
315 if (arg.is_sparse_type ()) | 342 if (arg.is_bool_type() && inf.nargout <= 1 && inf.n_to_find == -1) |
316 { | 343 { |
317 SparseBoolMatrix v = arg.sparse_bool_matrix_value (); | 344 // This case is equivalent to extracting indices from a logical |
318 | 345 // matrix. Try to reuse the possibly cached index vector. |
319 if (! error_state) | 346 retval(0) = arg.index_vector ().unmask (); |
320 retval = find::find_templated (v, nargout, n_to_find, dir); | 347 return retval; |
321 } | 348 } |
322 else if (nargout <= 1 && n_to_find == -1) | 349 |
323 { | 350 //Dispatches a call to the proper instantiation of the findTemplated |
324 // This case is equivalent to extracting indices from a logical | 351 //functor. This allows us to use the type of "arg" as a template |
325 // matrix. Try to reuse the possibly cached index vector. | 352 //argument to the find_to_iter function. |
326 retval(0) = arg.index_vector ().unmask (); | 353 return dispatch<find::find_templated> (arg, inf, "find"); |
327 } | |
328 else | |
329 { | |
330 boolNDArray v = arg.bool_array_value (); | |
331 | |
332 if (! error_state) | |
333 retval = find::find_templated (v, nargout, n_to_find, dir); | |
334 } | |
335 } | |
336 else if (arg.is_integer_type ()) | |
337 { | |
338 #define DO_INT_BRANCH(INTT) \ | |
339 else if (arg.is_ ## INTT ## _type ()) \ | |
340 { \ | |
341 INTT ## NDArray v = arg.INTT ## _array_value (); \ | |
342 \ | |
343 if (! error_state) \ | |
344 retval = find::find_templated (v, nargout, n_to_find, dir); \ | |
345 } | |
346 | |
347 if (false) | |
348 ; | |
349 DO_INT_BRANCH (int8) | |
350 DO_INT_BRANCH (int16) | |
351 DO_INT_BRANCH (int32) | |
352 DO_INT_BRANCH (int64) | |
353 DO_INT_BRANCH (uint8) | |
354 DO_INT_BRANCH (uint16) | |
355 DO_INT_BRANCH (uint32) | |
356 DO_INT_BRANCH (uint64) | |
357 else | |
358 panic_impossible (); | |
359 } | |
360 else if (arg.is_sparse_type ()) | |
361 { | |
362 if (arg.is_real_type ()) | |
363 { | |
364 SparseMatrix v = arg.sparse_matrix_value (); | |
365 | |
366 if (! error_state) | |
367 retval = find::find_templated (v, nargout, n_to_find, dir); | |
368 } | |
369 else if (arg.is_complex_type ()) | |
370 { | |
371 SparseComplexMatrix v = arg.sparse_complex_matrix_value (); | |
372 | |
373 if (! error_state) | |
374 retval = find::find_templated (v, nargout, n_to_find, dir); | |
375 } | |
376 else | |
377 gripe_wrong_type_arg ("find", arg); | |
378 } | |
379 else if (arg.is_perm_matrix ()) | |
380 { | |
381 PermMatrix P = arg.perm_matrix_value (); | |
382 | |
383 if (! error_state) | |
384 retval = find::find_templated (P, nargout, n_to_find, dir); | |
385 } | |
386 else if (arg.is_string ()) | |
387 { | |
388 charNDArray chnda = arg.char_array_value (); | |
389 | |
390 if (! error_state) | |
391 retval = find::find_templated (chnda, nargout, n_to_find, dir); | |
392 } | |
393 else if (arg.is_single_type ()) | |
394 { | |
395 if (arg.is_real_type ()) | |
396 { | |
397 FloatNDArray nda = arg.float_array_value (); | |
398 | |
399 if (! error_state) | |
400 retval = find::find_templated (nda, nargout, n_to_find, dir); | |
401 } | |
402 else if (arg.is_complex_type ()) | |
403 { | |
404 FloatComplexNDArray cnda = arg.float_complex_array_value (); | |
405 | |
406 if (! error_state) | |
407 retval = find::find_templated (cnda, nargout, n_to_find, dir); | |
408 } | |
409 } | |
410 else if (arg.is_real_type ()) | |
411 { | |
412 NDArray nda = arg.array_value (); | |
413 | |
414 if (! error_state) | |
415 retval = find::find_templated (nda, nargout, n_to_find, dir); | |
416 } | |
417 else if (arg.is_complex_type ()) | |
418 { | |
419 ComplexNDArray cnda = arg.complex_array_value (); | |
420 | |
421 if (! error_state) | |
422 retval = find::find_templated (cnda, nargout, n_to_find, dir); | |
423 } | |
424 else | |
425 gripe_wrong_type_arg ("find", arg); | |
426 | |
427 return retval; | |
428 } | 354 } |
429 | 355 |
430 /* | 356 /* |
431 %!assert (find (char ([0, 97])), 2) | 357 %!assert (find (char ([0, 97])), 2) |
432 %!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5]) | 358 %!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5]) |
479 %! x = sparse(100000, 30000); | 405 %! x = sparse(100000, 30000); |
480 %! x(end, end) = 1; | 406 %! x(end, end) = 1; |
481 %! i = find(x); | 407 %! i = find(x); |
482 %! assert (i == 3e09); | 408 %! assert (i == 3e09); |
483 | 409 |
410 %!test | |
411 %! fail("[a,b,c,d,e,f] = find(speye(3));") | |
412 | |
413 %!test | |
414 %! [i,j] = find(eye(1000000)); | |
415 | |
484 %!error find () | 416 %!error find () |
485 */ | 417 */ |