comparison libinterp/corefcn/find.cc @ 19008:80ca3b05d77c draft

New "dispatch" function selects the template argument from an octave_value (bugs #42424, #42425). * find.cc (Ffind): Call dispatch() rather than attempting to handle all matrix types directly. (find_templated): Change to a functor so it can be passed as a template template argument to dispatch(). (find_info): New struct holding the remaining arguments to find (n_to_find, direction, nargout). Add unit tests for bugs #42424 and #42425. * dispatch.h: New file. (dispatch): New function that dispatches a call to the correct template instantiation based on an octave_value argument.
author David Spies <dnspies@gmail.com>
date Sat, 21 Jun 2014 13:13:05 -0600
parents 2e0613dadfee
children
comparison
equal deleted inserted replaced
19006:2e0613dadfee 19008:80ca3b05d77c
23 23
24 #ifdef HAVE_CONFIG_H 24 #ifdef HAVE_CONFIG_H
25 #include <config.h> 25 #include <config.h>
26 #endif 26 #endif
27 27
28 #include "dispatch.h"
28 #include "find.h" 29 #include "find.h"
29 30
30 #include "defun.h" 31 #include "defun.h"
31 #include "error.h" 32 #include "error.h"
32 #include "gripes.h" 33 #include "gripes.h"
33 #include "oct-obj.h" 34 #include "oct-obj.h"
34 35
35 namespace find 36 namespace find
36 { 37 {
38 // The find function should be seen as the canonical example demonstrating
39 // how to properly call dispatch () (declared in dispatch.h).
40 // It should always behave properly for all matrix types.
41 //
37 // ffind_result is a generic type used for storing the result of 42 // ffind_result is a generic type used for storing the result of
38 // a find operation. The way in which this result is stored will 43 // a find operation. The way in which this result is stored will
39 // vary based on whether the number of requested return values 44 // vary based on whether the number of requested return values
40 // is 1, 2, or 3. 45 // is 1, 2, or 3.
41 // Each instantiation of ffind_result must support a couple different 46 // Each instantiation of ffind_result must support a couple different
179 panic_impossible (); // Checked by *** in Ffind 184 panic_impossible (); // Checked by *** in Ffind
180 } 185 }
181 return octave_value_list (); 186 return octave_value_list ();
182 } 187 }
183 188
189 struct find_info
190 {
191 octave_idx_type n_to_find;
192 direction dir;
193 int nargout;
194 };
195
196 // This functor will be called by dispatch () with the proper type M.
197 // This avoids having to explicitly list the different types find can
198 // handle and instead delegates that duty to the generic "dispatch"
199 // function.
184 template<typename M> 200 template<typename M>
185 octave_value_list 201 struct find_templated
186 find_templated (const M& v, int nargout, octave_idx_type n_to_find, 202 {
187 direction dir) 203 octave_value_list
188 { 204 operator() (const M& v, const find_info& inf)
189 return nargout_to_template (v, nargout, n_to_find, dir); 205 {
190 } 206 return nargout_to_template (v, inf.nargout, inf.n_to_find, inf.dir);
207 }
208 };
191 209
192 } 210 }
193 211
194 DEFUN (find, args, nargout, 212 DEFUN (find, args, nargout,
195 "-*- texinfo -*-\n\ 213 "-*- texinfo -*-\n\
253 @end group\n\ 271 @end group\n\
254 @end example\n\ 272 @end example\n\
255 @seealso{nonzeros}\n\ 273 @seealso{nonzeros}\n\
256 @end deftypefn") 274 @end deftypefn")
257 { 275 {
276 find::find_info inf;
277
278 if(nargout < 1)
279 nargout = 1;
280 else if(nargout > 3)
281 nargout = 3;
282
283 inf.nargout = nargout;
284
258 octave_value_list retval; 285 octave_value_list retval;
259 286
260 int nargin = args.length (); 287 int nargin = args.length ();
261 288
262 if (nargin > 3 || nargin < 1) 289 if (nargin > 3 || nargin < 1)
270 nargout = 1; 297 nargout = 1;
271 else if (nargout > 3) 298 else if (nargout > 3)
272 nargout = 3; 299 nargout = 3;
273 300
274 // Setup the default options. 301 // Setup the default options.
275 octave_idx_type n_to_find = -1; 302 inf.n_to_find = -1;
276 if (nargin > 1) 303 if (nargin > 1)
277 { 304 {
278 double val = args(1).scalar_value (); 305 double val = args(1).scalar_value ();
279 306
280 if (error_state || (val < 0 || (! xisinf (val) && val != xround (val)))) 307 if (error_state || (val < 0 || (! xisinf (val) && val != xround (val))))
281 { 308 {
282 error ("find: N must be a non-negative integer"); 309 error ("find: N must be a non-negative integer");
283 return retval; 310 return retval;
284 } 311 }
285 else if (! xisinf (val)) 312 else if (! xisinf (val))
286 n_to_find = val; 313 inf.n_to_find = val;
287 } 314 }
288 315
289 // Direction to do the searching. 316 // Direction to do the searching.
290 direction dir = FORWARD; 317 inf.dir = FORWARD;
291 if (nargin > 2) 318 if (nargin > 2)
292 { 319 {
293 std::string s_arg = args(2).string_value (); 320 std::string s_arg = args(2).string_value ();
294 321
295 if (error_state) 322 if (error_state)
296 { 323 {
297 error ("find: DIRECTION must be \"first\" or \"last\""); 324 error ("find: DIRECTION must be \"first\" or \"last\"");
298 return retval; 325 return retval;
299 } 326 }
300 if (s_arg == "first") 327 if (s_arg == "first")
301 dir = FORWARD; 328 inf.dir = FORWARD;
302 else if (s_arg == "last") 329 else if (s_arg == "last")
303 dir = BACKWARD; 330 inf.dir = BACKWARD;
304 else 331 else
305 { 332 {
306 error ("find: DIRECTION must be \"first\" or \"last\""); 333 error ("find: DIRECTION must be \"first\" or \"last\"");
307 return retval; 334 return retval;
308 } 335 }
309 } 336 }
310 337
311 octave_value arg = args(0); 338 const octave_value& arg = args(0);
312 339
313 if (arg.is_bool_type ()) 340 //For this special case, it's unnecessary to call dispatch because
314 { 341 //we already know the types of everything
315 if (arg.is_sparse_type ()) 342 if (arg.is_bool_type() && inf.nargout <= 1 && inf.n_to_find == -1)
316 { 343 {
317 SparseBoolMatrix v = arg.sparse_bool_matrix_value (); 344 // This case is equivalent to extracting indices from a logical
318 345 // matrix. Try to reuse the possibly cached index vector.
319 if (! error_state) 346 retval(0) = arg.index_vector ().unmask ();
320 retval = find::find_templated (v, nargout, n_to_find, dir); 347 return retval;
321 } 348 }
322 else if (nargout <= 1 && n_to_find == -1) 349
323 { 350 //Dispatches a call to the proper instantiation of the find_templated
324 // This case is equivalent to extracting indices from a logical 351 //functor. This allows us to use the type of "arg" as a template
325 // matrix. Try to reuse the possibly cached index vector. 352 //argument to the find_to_iter function.
326 retval(0) = arg.index_vector ().unmask (); 353 return dispatch<find::find_templated> (arg, inf, "find");
327 }
328 else
329 {
330 boolNDArray v = arg.bool_array_value ();
331
332 if (! error_state)
333 retval = find::find_templated (v, nargout, n_to_find, dir);
334 }
335 }
336 else if (arg.is_integer_type ())
337 {
338 #define DO_INT_BRANCH(INTT) \
339 else if (arg.is_ ## INTT ## _type ()) \
340 { \
341 INTT ## NDArray v = arg.INTT ## _array_value (); \
342 \
343 if (! error_state) \
344 retval = find::find_templated (v, nargout, n_to_find, dir); \
345 }
346
347 if (false)
348 ;
349 DO_INT_BRANCH (int8)
350 DO_INT_BRANCH (int16)
351 DO_INT_BRANCH (int32)
352 DO_INT_BRANCH (int64)
353 DO_INT_BRANCH (uint8)
354 DO_INT_BRANCH (uint16)
355 DO_INT_BRANCH (uint32)
356 DO_INT_BRANCH (uint64)
357 else
358 panic_impossible ();
359 }
360 else if (arg.is_sparse_type ())
361 {
362 if (arg.is_real_type ())
363 {
364 SparseMatrix v = arg.sparse_matrix_value ();
365
366 if (! error_state)
367 retval = find::find_templated (v, nargout, n_to_find, dir);
368 }
369 else if (arg.is_complex_type ())
370 {
371 SparseComplexMatrix v = arg.sparse_complex_matrix_value ();
372
373 if (! error_state)
374 retval = find::find_templated (v, nargout, n_to_find, dir);
375 }
376 else
377 gripe_wrong_type_arg ("find", arg);
378 }
379 else if (arg.is_perm_matrix ())
380 {
381 PermMatrix P = arg.perm_matrix_value ();
382
383 if (! error_state)
384 retval = find::find_templated (P, nargout, n_to_find, dir);
385 }
386 else if (arg.is_string ())
387 {
388 charNDArray chnda = arg.char_array_value ();
389
390 if (! error_state)
391 retval = find::find_templated (chnda, nargout, n_to_find, dir);
392 }
393 else if (arg.is_single_type ())
394 {
395 if (arg.is_real_type ())
396 {
397 FloatNDArray nda = arg.float_array_value ();
398
399 if (! error_state)
400 retval = find::find_templated (nda, nargout, n_to_find, dir);
401 }
402 else if (arg.is_complex_type ())
403 {
404 FloatComplexNDArray cnda = arg.float_complex_array_value ();
405
406 if (! error_state)
407 retval = find::find_templated (cnda, nargout, n_to_find, dir);
408 }
409 }
410 else if (arg.is_real_type ())
411 {
412 NDArray nda = arg.array_value ();
413
414 if (! error_state)
415 retval = find::find_templated (nda, nargout, n_to_find, dir);
416 }
417 else if (arg.is_complex_type ())
418 {
419 ComplexNDArray cnda = arg.complex_array_value ();
420
421 if (! error_state)
422 retval = find::find_templated (cnda, nargout, n_to_find, dir);
423 }
424 else
425 gripe_wrong_type_arg ("find", arg);
426
427 return retval;
428 } 354 }
429 355
430 /* 356 /*
431 %!assert (find (char ([0, 97])), 2) 357 %!assert (find (char ([0, 97])), 2)
432 %!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5]) 358 %!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5])
479 %! x = sparse(100000, 30000); 405 %! x = sparse(100000, 30000);
480 %! x(end, end) = 1; 406 %! x(end, end) = 1;
481 %! i = find(x); 407 %! i = find(x);
482 %! assert (i == 3e09); 408 %! assert (i == 3e09);
483 409
410 %!test
411 %! fail("[a,b,c,d,e,f] = find(speye(3));")
412
413 %!test
414 %! [i,j] = find(eye(1000000));
415
484 %!error find () 416 %!error find ()
485 */ 417 */