Mercurial > octave-dspies
diff libinterp/corefcn/find.cc @ 19006:2e0613dadfee draft
All calls to "find" use the same generic implementation (bug #42408, 42421)
* find.cc: Rewrite.
Move generic "find" logic to find.h
(Ffind) : Changed calls to find_nonzero_elem_idx to find_templated
Added unit test for bug #42421
* Array.cc (and .h) (Array::find): Deleted function. Replaced with find::find(Array)
from find.h
* Array.h: Added typedef for array_iterator (in nz-iterators.h) as
Array::iter_type
* DiagArray2.h: Added typedef for diag_iterator (in nz-iterators.h) as
DiagArray2::iter_type
* PermMatrix.h: Added typedef for perm_iterator (in nz-iterators.h) as
PermMatrix::iter_type
Also added typedef for bool as PermMatrix::element_type
(not octave_idx_type)
Added an nnz() function (which is an alias for perm_length) and a
perm_elem(i) function for retrieving the ith element of the permutation
* Sparse.h: Added typedef for sparse_iterator (in nz-iterators.h) as
Sparse::iter_type
Added a short comment documenting the argument to the numel
function
* idx-vector.cc (idx_vector::idx_mask_rep::as_array): Changed Array.find to
find::find(Array) (in find.h)
* (new file) find.h
* (new file) interp-idx.h: Simple methods for converting between interpreter
index type and internal octave_idx_type/row-col pair
* (new file) min-with-nnz.h: Fast methods for taking an arbitrary matrix M and
an octave_idx_type n and finding min(M.nnz(), n)
* (new file) nz-iterators.h: Iterators for traversing (in column-major order)
the nonzero elements of any array or matrix backwards or forwards
* (new file) direction.h: Generic methods for simplifying code that has to deal with
a "backwards or forwards" template argument
* build-sparse-tests.sh: Removed 5-return-value calls to "find" in unit-tests;
Admittedly this commit breaks this "feature" which was undocumented and only
partially supported to begin with (i.e. never worked for full matrices,
permutation matrices, or diagonal matrices)
author | David Spies <dnspies@gmail.com> |
---|---|
date | Tue, 17 Jun 2014 16:41:11 -0600 |
parents | aa9ca67f09fb |
children | 80ca3b05d77c |
line wrap: on
line diff
--- a/libinterp/corefcn/find.cc Mon Aug 11 09:39:45 2014 -0700 +++ b/libinterp/corefcn/find.cc Tue Jun 17 16:41:11 2014 -0600 @@ -1,6 +1,7 @@ /* Copyright (C) 1996-2013 John W. Eaton +Copyright (C) 2014 David Spies This file is part of Octave. @@ -24,305 +25,170 @@ #include <config.h> #endif -#include "quit.h" +#include "find.h" #include "defun.h" #include "error.h" #include "gripes.h" #include "oct-obj.h" -// Find at most N_TO_FIND nonzero elements in NDA. Search forward if -// DIRECTION is 1, backward if it is -1. NARGOUT is the number of -// output arguments. If N_TO_FIND is -1, find all nonzero elements. - -template <typename T> -octave_value_list -find_nonzero_elem_idx (const Array<T>& nda, int nargout, - octave_idx_type n_to_find, int direction) +namespace find { - octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); + // ffind_result is a generic type used for storing the result of + // a find operation. The way in which this result is stored will + // vary based on whether the number of requested return values + // is 1, 2, or 3. + // Each instantiation of ffind_result must support a couple different + // operations. It is constructed with a dim_vector which indicates + // the size of the return vectors (the size of all return values for + // find is the same). It supports an add() operation which + // the generic "find" method will call to add an element to the return + // values. "add" takes an index and an nz-iterator for the matrix type + // being searched (from nz-iterators.h) + // Finally it supports get_list which returns an octave_value_list of the + // return values (1,2, or 3 depending on the nargout template argument) - Array<octave_idx_type> idx; - if (n_to_find >= 0) - idx = nda.find (n_to_find, direction == -1); - else - idx = nda.find (); + template<int nargout, typename T> + struct ffind_result; - // The maximum element is always at the end. - octave_idx_type iext = idx.is_empty () ? 
0 : idx.xelem (idx.numel () - 1) + 1; + template<typename T> + struct ffind_result<1, T> + { + ffind_result (void) { } + ffind_result (const dim_vector& nnz) : res (nnz) { } - switch (nargout) + Array<double> res; + + template<typename iter_t> + void + add (octave_idx_type place, const iter_t& iter) { - default: - case 3: - retval(2) = Array<T> (nda.index (idx_vector (idx))); - // Fall through! - - case 2: - { - Array<octave_idx_type> jdx (idx.dims ()); - octave_idx_type n = idx.length (); - octave_idx_type nr = nda.rows (); - for (octave_idx_type i = 0; i < n; i++) - { - jdx.xelem (i) = idx.xelem (i) / nr; - idx.xelem (i) %= nr; - } - iext = -1; - retval(1) = idx_vector (jdx, -1); - } - // Fall through! - - case 1: - case 0: - retval(0) = idx_vector (idx, iext); - break; + res.xelem (place) = iter.interp_idx (); } - return retval; -} - -template <typename T> -octave_value_list -find_nonzero_elem_idx (const Sparse<T>& v, int nargout, - octave_idx_type n_to_find, int direction) -{ - octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); - - octave_idx_type nr = v.rows (); - octave_idx_type nc = v.cols (); - octave_idx_type nz = v.nnz (); - - // Search in the default range. 
- octave_idx_type start_nc = -1; - octave_idx_type end_nc = -1; - octave_idx_type count; - - // Search for the range to search - if (n_to_find < 0) - { - start_nc = 0; - end_nc = nc; - n_to_find = nz; - count = nz; - } - else if (direction > 0) + octave_value_list + get_list (void) { - for (octave_idx_type j = 0; j < nc; j++) - { - OCTAVE_QUIT; - if (v.cidx (j) == 0 && v.cidx (j+1) != 0) - start_nc = j; - if (v.cidx (j+1) >= n_to_find) - { - end_nc = j + 1; - break; - } - } + return octave_value_list (octave_value (res)); } - else + }; + + template<typename T> + struct ffind_result<2, T> + { + ffind_result (void) { } + ffind_result (const dim_vector& nnz) : rescol (nnz), resrow (nnz) { } + + Array<double> rescol; + Array<double> resrow; + + template<typename iter_t> + void + add (octave_idx_type place, const iter_t& iter) { - for (octave_idx_type j = nc; j > 0; j--) - { - OCTAVE_QUIT; - if (v.cidx (j) == nz && v.cidx (j-1) != nz) - end_nc = j; - if (nz - v.cidx (j-1) >= n_to_find) - { - start_nc = j - 1; - break; - } - } - } - - count = (n_to_find > v.cidx (end_nc) - v.cidx (start_nc) ? - v.cidx (end_nc) - v.cidx (start_nc) : n_to_find); - - octave_idx_type result_nr; - octave_idx_type result_nc; - - // Default case is to return a column vector, however, if the original - // argument was a row vector, then force return of a row vector. - if (nr == 1) - { - result_nr = 1; - result_nc = count; - } - else - { - result_nr = count; - result_nc = 1; + rescol.xelem (place) = to_interp_idx (iter.col ()); + resrow.xelem (place) = to_interp_idx (iter.row ()); } - Matrix idx (result_nr, result_nc); - - Matrix i_idx (result_nr, result_nc); - Matrix j_idx (result_nr, result_nc); - - Array<T> val (dim_vector (result_nr, result_nc)); - - if (count > 0) + octave_value_list + get_list (void) { - // Search for elements to return. Only search the region where there - // are elements to be found using the count that we want to find. 
- for (octave_idx_type j = start_nc, cx = 0; j < end_nc; j++) - for (octave_idx_type i = v.cidx (j); i < v.cidx (j+1); i++) - { - OCTAVE_QUIT; - if (direction < 0 && i < nz - count) - continue; - i_idx(cx) = static_cast<double> (v.ridx (i) + 1); - j_idx(cx) = static_cast<double> (j + 1); - idx(cx) = j * nr + v.ridx (i) + 1; - val(cx) = v.data(i); - cx++; - if (cx == count) - break; - } + octave_value_list res (2); + res.xelem (0) = resrow; + res.xelem (1) = rescol; + return res; } - else - { - // No items found. Fixup return dimensions for Matlab compatibility. - // The behavior to match is documented in Array.cc (Array<T>::find). - if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1)) - { - idx.resize (0, 0); + }; - i_idx.resize (0, 0); - j_idx.resize (0, 0); - - val.resize (dim_vector (0, 0)); - } - } - - switch (nargout) - { - case 0: - case 1: - retval(0) = idx; - break; + template<typename T> + struct ffind_result<3, T> + { + ffind_result (void) { } + ffind_result (const dim_vector& nnz) : + rescol (nnz), resrow (nnz), elems (nnz) { } + Array<double> rescol; + Array<double> resrow; + Array<T> elems; - case 5: - retval(4) = nc; - // Fall through - - case 4: - retval(3) = nr; - // Fall through - - case 3: - retval(2) = val; - // Fall through! - - case 2: - retval(1) = j_idx; - retval(0) = i_idx; - break; - - default: - panic_impossible (); - break; + template<typename iter_t> + void + add (octave_idx_type place, const iter_t& iter) + { + rescol.xelem (place) = to_interp_idx (iter.col ()); + resrow.xelem (place) = to_interp_idx (iter.row ()); + elems.xelem (place) = iter.data (); } - return retval; -} - -octave_value_list -find_nonzero_elem_idx (const PermMatrix& v, int nargout, - octave_idx_type n_to_find, int direction) -{ - // There are far fewer special cases to handle for a PermMatrix. - octave_value_list retval ((nargout == 0 ? 
1 : nargout), Matrix ()); - - octave_idx_type nr = v.rows (); - octave_idx_type nc = v.cols (); - octave_idx_type start_nc, count; - - // Determine the range to search. - if (n_to_find < 0 || n_to_find >= nc) - { - start_nc = 0; - count = nc; - } - else if (direction > 0) + octave_value_list + get_list (void) { - start_nc = 0; - count = n_to_find; - } - else - { - start_nc = nc - n_to_find; - count = n_to_find; - } - - Matrix idx (count, 1); - Matrix i_idx (count, 1); - Matrix j_idx (count, 1); - // Every value is 1. - Array<double> val (dim_vector (count, 1), 1.0); - - if (count > 0) - { - const Array<octave_idx_type>& p = v.col_perm_vec (); - for (octave_idx_type k = 0; k < count; k++) - { - OCTAVE_QUIT; - const octave_idx_type j = start_nc + k; - const octave_idx_type i = p(j); - i_idx(k) = static_cast<double> (1+i); - j_idx(k) = static_cast<double> (1+j); - idx(k) = j * nc + i + 1; - } + octave_value_list res (3); + res.xelem (0) = resrow; + res.xelem (1) = rescol; + res.xelem (2) = elems; + return res; } - else - { - // FIXME: Is this case even possible? A scalar permutation matrix seems - // to devolve to a scalar full matrix, at least from the Octave command - // line. Perhaps this function could be called internally from C++ with - // such a matrix. - // No items found. Fixup return dimensions for Matlab compatibility. - // The behavior to match is documented in Array.cc (Array<T>::find). - if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1)) - { - idx.resize (0, 0); + }; - i_idx.resize (0, 0); - j_idx.resize (0, 0); + // Calls the to_R template method in "find.h" which + // in turn will fill in resvec of type ffind_result (see above). 
+ // This is generic enough to work for any matrix type M + template<direction dir, int nargout, typename M> + octave_value_list + call (const M& v, octave_idx_type n_to_find) + { + ffind_result<nargout, typename M::element_type> resvec; + dir_handler<dir> dirc; + resvec = + find_to_R<ffind_result<nargout, typename M::element_type> > (dirc, + v, n_to_find); - val.resize (dim_vector (0, 0)); - } - } + return resvec.get_list (); + } - switch (nargout) - { - case 0: - case 1: - retval(0) = idx; - break; - - case 5: - retval(4) = nc; - // Fall through - - case 4: - retval(3) = nc; - // Fall through + template<int nargout, typename M> + octave_value_list + dir_to_template (const M& v, octave_idx_type n_to_find, direction dir) + { + switch (dir) + { + case BACKWARD: + return call<BACKWARD, nargout> (v, n_to_find); + case FORWARD: + return call<FORWARD, nargout> (v, n_to_find); + default: + panic_impossible (); + } + return octave_value_list (); + } - case 3: - retval(2) = val; - // Fall through! + template<typename M> + octave_value_list + nargout_to_template (const M& v, int nargout, octave_idx_type n_to_find, + direction dir) + { + switch (nargout) + { + case 1: + return dir_to_template<1> (v, n_to_find, dir); + case 2: + return dir_to_template<2> (v, n_to_find, dir); + case 3: + return dir_to_template<3> (v, n_to_find, dir); + default: + panic_impossible (); // Checked by *** in Ffind + } + return octave_value_list (); + } - case 2: - retval(1) = j_idx; - retval(0) = i_idx; - break; + template<typename M> + octave_value_list + find_templated (const M& v, int nargout, octave_idx_type n_to_find, + direction dir) + { + return nargout_to_template (v, nargout, n_to_find, dir); + } - default: - panic_impossible (); - break; - } - - return retval; } DEFUN (find, args, nargout, @@ -399,6 +265,12 @@ return retval; } + // *** + if (nargout < 1) + nargout = 1; + else if (nargout > 3) + nargout = 3; + // Setup the default options. 
octave_idx_type n_to_find = -1; if (nargin > 1) @@ -414,23 +286,22 @@ n_to_find = val; } - // Direction to do the searching (1 == forward, -1 == reverse). - int direction = 1; + // Direction to do the searching. + direction dir = FORWARD; if (nargin > 2) { - direction = 0; - std::string s_arg = args(2).string_value (); - if (! error_state) + if (error_state) { - if (s_arg == "first") - direction = 1; - else if (s_arg == "last") - direction = -1; + error ("find: DIRECTION must be \"first\" or \"last\""); + return retval; } - - if (direction == 0) + if (s_arg == "first") + dir = FORWARD; + else if (s_arg == "last") + dir = BACKWARD; + else { error ("find: DIRECTION must be \"first\" or \"last\""); return retval; @@ -446,10 +317,9 @@ SparseBoolMatrix v = arg.sparse_bool_matrix_value (); if (! error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } - else if (nargout <= 1 && n_to_find == -1 && direction == 1) + else if (nargout <= 1 && n_to_find == -1) { // This case is equivalent to extracting indices from a logical // matrix. Try to reuse the possibly cached index vector. @@ -460,20 +330,18 @@ boolNDArray v = arg.bool_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } } else if (arg.is_integer_type ()) { -#define DO_INT_BRANCH(INTT) \ - else if (arg.is_ ## INTT ## _type ()) \ - { \ - INTT ## NDArray v = arg.INTT ## _array_value (); \ - \ - if (! error_state) \ - retval = find_nonzero_elem_idx (v, nargout, \ - n_to_find, direction);\ +#define DO_INT_BRANCH(INTT) \ + else if (arg.is_ ## INTT ## _type ()) \ + { \ + INTT ## NDArray v = arg.INTT ## _array_value (); \ + \ + if (! error_state) \ + retval = find::find_templated (v, nargout, n_to_find, dir); \ } if (false) @@ -496,16 +364,14 @@ SparseMatrix v = arg.sparse_matrix_value (); if (! 
error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } else if (arg.is_complex_type ()) { SparseComplexMatrix v = arg.sparse_complex_matrix_value (); if (! error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } else gripe_wrong_type_arg ("find", arg); @@ -515,14 +381,14 @@ PermMatrix P = arg.perm_matrix_value (); if (! error_state) - retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction); + retval = find::find_templated (P, nargout, n_to_find, dir); } else if (arg.is_string ()) { charNDArray chnda = arg.char_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (chnda, nargout, n_to_find, direction); + retval = find::find_templated (chnda, nargout, n_to_find, dir); } else if (arg.is_single_type ()) { @@ -531,16 +397,14 @@ FloatNDArray nda = arg.float_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (nda, nargout, n_to_find, - direction); + retval = find::find_templated (nda, nargout, n_to_find, dir); } else if (arg.is_complex_type ()) { FloatComplexNDArray cnda = arg.float_complex_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, - direction); + retval = find::find_templated (cnda, nargout, n_to_find, dir); } } else if (arg.is_real_type ()) @@ -548,14 +412,14 @@ NDArray nda = arg.array_value (); if (! error_state) - retval = find_nonzero_elem_idx (nda, nargout, n_to_find, direction); + retval = find::find_templated (nda, nargout, n_to_find, dir); } else if (arg.is_complex_type ()) { ComplexNDArray cnda = arg.complex_array_value (); if (! 
error_state) - retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, direction); + retval = find::find_templated (cnda, nargout, n_to_find, dir); } else gripe_wrong_type_arg ("find", arg); @@ -611,5 +475,11 @@ %!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5]) %!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5]) +%!test +%! x = sparse(100000, 30000); +%! x(end, end) = 1; +%! i = find(x); +%! assert (i == 3e09); + %!error find () */