# HG changeset patch # User David Spies # Date 1403044871 21600 # Node ID 2e0613dadfee107ae067f873175d82bbc9155c2e # Parent d00f6b09258fa5632ff45aa345bf6594d9f232a8 All calls to "find" use the same generic implementation (bug #42408, 42421) * find.cc: Rewrite. Move generic "find" logic to find.h (Ffind) : Changed calls to find_nonzero_elem_idx to find_templated Added unit test for bug #42421 * Array.cc (and .h) (Array::find): Deleted function. Replaced with find::find(Array) from find.h * Array.h: Added typedef for array_iterator (in nz-iterators.h) as Array::iter_type * DiagArray2.h: Added typedef for diag_iterator (in nz-iterators.h) as DiagArray2::iter_type * PermMatrix.h: Added typedef for perm_iterator (in nz-iterators.h) as PermMatrix::iter_type Also added typedef for bool as PermMatrix::element_type (not octave_idx_type) Added an nnz() function (which is an alias for perm_length) and a perm_elem(i) function for retrieving the ith element of the permutation * Sparse.h: Added typedef for sparse_iterator (in nz-iterators.h) as Sparse::iter_type * dim-vector.h: Added a short comment documenting the argument to the numel function * idx-vector.cc (idx_vector::idx_mask_rep::as_array): Changed Array.find to find::find(Array) (in find.h) * (new file) find.h * (new file) interp-idx.h: Simple methods for converting between interpreter index type and internal octave_idx_type/row-col pair * (new file) min-with-nnz.h: Fast methods for taking an arbitrary matrix M and an octave_idx_type n and finding min(M.nnz(), n) * (new file) nz-iterators.h: Iterators for traversing (in column-major order) the nonzero elements of any array or matrix backwards or forwards * (new file) direction.h: Generic methods for simplifying code that has to deal with a "backwards or forwards" template argument * build-sparse-tests.sh: Removed 5-return-value calls to "find" in unit-tests; Admittedly this commit breaks this "feature" which was undocumented and only partially supported to begin with (i.e. never worked 
for full matrices, permutation matrices, or diagonal matrices) diff -r d00f6b09258f -r 2e0613dadfee libinterp/corefcn/find.cc --- a/libinterp/corefcn/find.cc Mon Aug 11 09:39:45 2014 -0700 +++ b/libinterp/corefcn/find.cc Tue Jun 17 16:41:11 2014 -0600 @@ -1,6 +1,7 @@ /* Copyright (C) 1996-2013 John W. Eaton +Copyright (C) 2014 David Spies This file is part of Octave. @@ -24,305 +25,170 @@ #include #endif -#include "quit.h" +#include "find.h" #include "defun.h" #include "error.h" #include "gripes.h" #include "oct-obj.h" -// Find at most N_TO_FIND nonzero elements in NDA. Search forward if -// DIRECTION is 1, backward if it is -1. NARGOUT is the number of -// output arguments. If N_TO_FIND is -1, find all nonzero elements. - -template -octave_value_list -find_nonzero_elem_idx (const Array& nda, int nargout, - octave_idx_type n_to_find, int direction) +namespace find { - octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); + // ffind_result is a generic type used for storing the result of + // a find operation. The way in which this result is stored will + // vary based on whether the number of requested return values + // is 1, 2, or 3. + // Each instantiation of ffind_result must support a couple different + // operations. It is constructed with a dim_vector which indicates + // the size of the return vectors (the size of all return values for + // find is the same). It supports an add() operation which + // the generic "find" method will call to add an element to the return + // values. "add" takes an index and an nz-iterator for the matrix type + // being searched (from nz-iterators.h) + // Finally it supports get_list which returns an octave_value_list of the + // return values (1,2, or 3 depending on the nargout template argument) - Array idx; - if (n_to_find >= 0) - idx = nda.find (n_to_find, direction == -1); - else - idx = nda.find (); + template + struct ffind_result; - // The maximum element is always at the end. 
- octave_idx_type iext = idx.is_empty () ? 0 : idx.xelem (idx.numel () - 1) + 1; + template + struct ffind_result<1, T> + { + ffind_result (void) { } + ffind_result (const dim_vector& nnz) : res (nnz) { } - switch (nargout) + Array res; + + template + void + add (octave_idx_type place, const iter_t& iter) { - default: - case 3: - retval(2) = Array (nda.index (idx_vector (idx))); - // Fall through! - - case 2: - { - Array jdx (idx.dims ()); - octave_idx_type n = idx.length (); - octave_idx_type nr = nda.rows (); - for (octave_idx_type i = 0; i < n; i++) - { - jdx.xelem (i) = idx.xelem (i) / nr; - idx.xelem (i) %= nr; - } - iext = -1; - retval(1) = idx_vector (jdx, -1); - } - // Fall through! - - case 1: - case 0: - retval(0) = idx_vector (idx, iext); - break; + res.xelem (place) = iter.interp_idx (); } - return retval; -} - -template -octave_value_list -find_nonzero_elem_idx (const Sparse& v, int nargout, - octave_idx_type n_to_find, int direction) -{ - octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); - - octave_idx_type nr = v.rows (); - octave_idx_type nc = v.cols (); - octave_idx_type nz = v.nnz (); - - // Search in the default range. 
- octave_idx_type start_nc = -1; - octave_idx_type end_nc = -1; - octave_idx_type count; - - // Search for the range to search - if (n_to_find < 0) - { - start_nc = 0; - end_nc = nc; - n_to_find = nz; - count = nz; - } - else if (direction > 0) + octave_value_list + get_list (void) { - for (octave_idx_type j = 0; j < nc; j++) - { - OCTAVE_QUIT; - if (v.cidx (j) == 0 && v.cidx (j+1) != 0) - start_nc = j; - if (v.cidx (j+1) >= n_to_find) - { - end_nc = j + 1; - break; - } - } + return octave_value_list (octave_value (res)); } - else + }; + + template + struct ffind_result<2, T> + { + ffind_result (void) { } + ffind_result (const dim_vector& nnz) : rescol (nnz), resrow (nnz) { } + + Array rescol; + Array resrow; + + template + void + add (octave_idx_type place, const iter_t& iter) { - for (octave_idx_type j = nc; j > 0; j--) - { - OCTAVE_QUIT; - if (v.cidx (j) == nz && v.cidx (j-1) != nz) - end_nc = j; - if (nz - v.cidx (j-1) >= n_to_find) - { - start_nc = j - 1; - break; - } - } - } - - count = (n_to_find > v.cidx (end_nc) - v.cidx (start_nc) ? - v.cidx (end_nc) - v.cidx (start_nc) : n_to_find); - - octave_idx_type result_nr; - octave_idx_type result_nc; - - // Default case is to return a column vector, however, if the original - // argument was a row vector, then force return of a row vector. - if (nr == 1) - { - result_nr = 1; - result_nc = count; - } - else - { - result_nr = count; - result_nc = 1; + rescol.xelem (place) = to_interp_idx (iter.col ()); + resrow.xelem (place) = to_interp_idx (iter.row ()); } - Matrix idx (result_nr, result_nc); - - Matrix i_idx (result_nr, result_nc); - Matrix j_idx (result_nr, result_nc); - - Array val (dim_vector (result_nr, result_nc)); - - if (count > 0) + octave_value_list + get_list (void) { - // Search for elements to return. Only search the region where there - // are elements to be found using the count that we want to find. 
- for (octave_idx_type j = start_nc, cx = 0; j < end_nc; j++) - for (octave_idx_type i = v.cidx (j); i < v.cidx (j+1); i++) - { - OCTAVE_QUIT; - if (direction < 0 && i < nz - count) - continue; - i_idx(cx) = static_cast (v.ridx (i) + 1); - j_idx(cx) = static_cast (j + 1); - idx(cx) = j * nr + v.ridx (i) + 1; - val(cx) = v.data(i); - cx++; - if (cx == count) - break; - } + octave_value_list res (2); + res.xelem (0) = resrow; + res.xelem (1) = rescol; + return res; } - else - { - // No items found. Fixup return dimensions for Matlab compatibility. - // The behavior to match is documented in Array.cc (Array::find). - if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1)) - { - idx.resize (0, 0); + }; - i_idx.resize (0, 0); - j_idx.resize (0, 0); - - val.resize (dim_vector (0, 0)); - } - } - - switch (nargout) - { - case 0: - case 1: - retval(0) = idx; - break; + template + struct ffind_result<3, T> + { + ffind_result (void) { } + ffind_result (const dim_vector& nnz) : + rescol (nnz), resrow (nnz), elems (nnz) { } + Array rescol; + Array resrow; + Array elems; - case 5: - retval(4) = nc; - // Fall through - - case 4: - retval(3) = nr; - // Fall through - - case 3: - retval(2) = val; - // Fall through! - - case 2: - retval(1) = j_idx; - retval(0) = i_idx; - break; - - default: - panic_impossible (); - break; + template + void + add (octave_idx_type place, const iter_t& iter) + { + rescol.xelem (place) = to_interp_idx (iter.col ()); + resrow.xelem (place) = to_interp_idx (iter.row ()); + elems.xelem (place) = iter.data (); } - return retval; -} - -octave_value_list -find_nonzero_elem_idx (const PermMatrix& v, int nargout, - octave_idx_type n_to_find, int direction) -{ - // There are far fewer special cases to handle for a PermMatrix. - octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); - - octave_idx_type nr = v.rows (); - octave_idx_type nc = v.cols (); - octave_idx_type start_nc, count; - - // Determine the range to search. 
- if (n_to_find < 0 || n_to_find >= nc) - { - start_nc = 0; - count = nc; - } - else if (direction > 0) + octave_value_list + get_list (void) { - start_nc = 0; - count = n_to_find; - } - else - { - start_nc = nc - n_to_find; - count = n_to_find; - } - - Matrix idx (count, 1); - Matrix i_idx (count, 1); - Matrix j_idx (count, 1); - // Every value is 1. - Array val (dim_vector (count, 1), 1.0); - - if (count > 0) - { - const Array& p = v.col_perm_vec (); - for (octave_idx_type k = 0; k < count; k++) - { - OCTAVE_QUIT; - const octave_idx_type j = start_nc + k; - const octave_idx_type i = p(j); - i_idx(k) = static_cast (1+i); - j_idx(k) = static_cast (1+j); - idx(k) = j * nc + i + 1; - } + octave_value_list res (3); + res.xelem (0) = resrow; + res.xelem (1) = rescol; + res.xelem (2) = elems; + return res; } - else - { - // FIXME: Is this case even possible? A scalar permutation matrix seems - // to devolve to a scalar full matrix, at least from the Octave command - // line. Perhaps this function could be called internally from C++ with - // such a matrix. - // No items found. Fixup return dimensions for Matlab compatibility. - // The behavior to match is documented in Array.cc (Array::find). - if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1)) - { - idx.resize (0, 0); + }; - i_idx.resize (0, 0); - j_idx.resize (0, 0); + // Calls the to_R template method in "find.h" which + // in turn will fill in resvec of type ffind_result (see above). 
+ // This is generic enough to work for any matrix type M + template + octave_value_list + call (const M& v, octave_idx_type n_to_find) + { + ffind_result resvec; + dir_handler dirc; + resvec = + find_to_R > (dirc, + v, n_to_find); - val.resize (dim_vector (0, 0)); - } - } + return resvec.get_list (); + } - switch (nargout) - { - case 0: - case 1: - retval(0) = idx; - break; - - case 5: - retval(4) = nc; - // Fall through - - case 4: - retval(3) = nc; - // Fall through + template + octave_value_list + dir_to_template (const M& v, octave_idx_type n_to_find, direction dir) + { + switch (dir) + { + case BACKWARD: + return call (v, n_to_find); + case FORWARD: + return call (v, n_to_find); + default: + panic_impossible (); + } + return octave_value_list (); + } - case 3: - retval(2) = val; - // Fall through! + template + octave_value_list + nargout_to_template (const M& v, int nargout, octave_idx_type n_to_find, + direction dir) + { + switch (nargout) + { + case 1: + return dir_to_template<1> (v, n_to_find, dir); + case 2: + return dir_to_template<2> (v, n_to_find, dir); + case 3: + return dir_to_template<3> (v, n_to_find, dir); + default: + panic_impossible (); // Checked by *** in Ffind + } + return octave_value_list (); + } - case 2: - retval(1) = j_idx; - retval(0) = i_idx; - break; + template + octave_value_list + find_templated (const M& v, int nargout, octave_idx_type n_to_find, + direction dir) + { + return nargout_to_template (v, nargout, n_to_find, dir); + } - default: - panic_impossible (); - break; - } - - return retval; } DEFUN (find, args, nargout, @@ -399,6 +265,12 @@ return retval; } + // *** + if (nargout < 1) + nargout = 1; + else if (nargout > 3) + nargout = 3; + // Setup the default options. octave_idx_type n_to_find = -1; if (nargin > 1) @@ -414,23 +286,22 @@ n_to_find = val; } - // Direction to do the searching (1 == forward, -1 == reverse). - int direction = 1; + // Direction to do the searching. 
+ direction dir = FORWARD; if (nargin > 2) { - direction = 0; - std::string s_arg = args(2).string_value (); - if (! error_state) + if (error_state) { - if (s_arg == "first") - direction = 1; - else if (s_arg == "last") - direction = -1; + error ("find: DIRECTION must be \"first\" or \"last\""); + return retval; } - - if (direction == 0) + if (s_arg == "first") + dir = FORWARD; + else if (s_arg == "last") + dir = BACKWARD; + else { error ("find: DIRECTION must be \"first\" or \"last\""); return retval; @@ -446,10 +317,9 @@ SparseBoolMatrix v = arg.sparse_bool_matrix_value (); if (! error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } - else if (nargout <= 1 && n_to_find == -1 && direction == 1) + else if (nargout <= 1 && n_to_find == -1) { // This case is equivalent to extracting indices from a logical // matrix. Try to reuse the possibly cached index vector. @@ -460,20 +330,18 @@ boolNDArray v = arg.bool_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } } else if (arg.is_integer_type ()) { -#define DO_INT_BRANCH(INTT) \ - else if (arg.is_ ## INTT ## _type ()) \ - { \ - INTT ## NDArray v = arg.INTT ## _array_value (); \ - \ - if (! error_state) \ - retval = find_nonzero_elem_idx (v, nargout, \ - n_to_find, direction);\ +#define DO_INT_BRANCH(INTT) \ + else if (arg.is_ ## INTT ## _type ()) \ + { \ + INTT ## NDArray v = arg.INTT ## _array_value (); \ + \ + if (! error_state) \ + retval = find::find_templated (v, nargout, n_to_find, dir); \ } if (false) @@ -496,16 +364,14 @@ SparseMatrix v = arg.sparse_matrix_value (); if (! 
error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } else if (arg.is_complex_type ()) { SparseComplexMatrix v = arg.sparse_complex_matrix_value (); if (! error_state) - retval = find_nonzero_elem_idx (v, nargout, - n_to_find, direction); + retval = find::find_templated (v, nargout, n_to_find, dir); } else gripe_wrong_type_arg ("find", arg); @@ -515,14 +381,14 @@ PermMatrix P = arg.perm_matrix_value (); if (! error_state) - retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction); + retval = find::find_templated (P, nargout, n_to_find, dir); } else if (arg.is_string ()) { charNDArray chnda = arg.char_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (chnda, nargout, n_to_find, direction); + retval = find::find_templated (chnda, nargout, n_to_find, dir); } else if (arg.is_single_type ()) { @@ -531,16 +397,14 @@ FloatNDArray nda = arg.float_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (nda, nargout, n_to_find, - direction); + retval = find::find_templated (nda, nargout, n_to_find, dir); } else if (arg.is_complex_type ()) { FloatComplexNDArray cnda = arg.float_complex_array_value (); if (! error_state) - retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, - direction); + retval = find::find_templated (cnda, nargout, n_to_find, dir); } } else if (arg.is_real_type ()) @@ -548,14 +412,14 @@ NDArray nda = arg.array_value (); if (! error_state) - retval = find_nonzero_elem_idx (nda, nargout, n_to_find, direction); + retval = find::find_templated (nda, nargout, n_to_find, dir); } else if (arg.is_complex_type ()) { ComplexNDArray cnda = arg.complex_array_value (); if (! 
error_state) - retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, direction); + retval = find::find_templated (cnda, nargout, n_to_find, dir); } else gripe_wrong_type_arg ("find", arg); @@ -611,5 +475,11 @@ %!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5]) %!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5]) +%!test +%! x = sparse(100000, 30000); +%! x(end, end) = 1; +%! i = find(x); +%! assert (i == 3e09); + %!error find () */ diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/Array.cc --- a/liboctave/array/Array.cc Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/Array.cc Tue Jun 17 16:41:11 2014 -0600 @@ -2242,87 +2242,6 @@ } template -Array -Array::find (octave_idx_type n, bool backward) const -{ - Array retval; - const T *src = data (); - octave_idx_type nel = nelem (); - const T zero = T (); - if (n < 0 || n >= nel) - { - // We want all elements, which means we'll almost surely need - // to resize. So count first, then allocate array of exact size. - octave_idx_type cnt = 0; - for (octave_idx_type i = 0; i < nel; i++) - cnt += src[i] != zero; - - retval.clear (cnt, 1); - octave_idx_type *dest = retval.fortran_vec (); - for (octave_idx_type i = 0; i < nel; i++) - if (src[i] != zero) *dest++ = i; - } - else - { - // We want a fixed max number of elements, usually small. So be - // optimistic, alloc the array in advance, and then resize if - // needed. - retval.clear (n, 1); - if (backward) - { - // Do the search as a series of successive single-element searches. - octave_idx_type k = 0; - octave_idx_type l = nel - 1; - for (; k < n; k++) - { - for (; l >= 0 && src[l] == zero; l--) ; - if (l >= 0) - retval(k) = l--; - else - break; - } - if (k < n) - retval.resize2 (k, 1); - octave_idx_type *rdata = retval.fortran_vec (); - std::reverse (rdata, rdata + k); - } - else - { - // Do the search as a series of successive single-element searches. 
- octave_idx_type k = 0; - octave_idx_type l = 0; - for (; k < n; k++) - { - for (; l != nel && src[l] == zero; l++) ; - if (l != nel) - retval(k) = l++; - else - break; - } - if (k < n) - retval.resize2 (k, 1); - } - } - - // Fixup return dimensions, for Matlab compatibility. - // find (zeros (0,0)) -> zeros (0,0) - // find (zeros (1,0)) -> zeros (1,0) - // find (zeros (0,1)) -> zeros (0,1) - // find (zeros (0,X)) -> zeros (0,1) - // find (zeros (1,1)) -> zeros (0,0) !!!! WHY? - // find (zeros (0,1,0)) -> zeros (0,0) - // find (zeros (0,1,0,1)) -> zeros (0,0) etc - - if ((numel () == 1 && retval.is_empty ()) - || (rows () == 0 && dims ().numel (1) == 0)) - retval.dimensions = dim_vector (); - else if (rows () == 1 && ndims () == 2) - retval.dimensions = dim_vector (1, retval.length ()); - - return retval; -} - -template Array Array::nth_element (const idx_vector& n, int dim) const { @@ -2516,9 +2435,6 @@ template <> octave_idx_type \ Array::nnz (void) const\ { return 0; } \ -template <> Array \ -Array::find (octave_idx_type, bool) const\ -{ return Array (); } \ \ template <> Array \ Array::nth_element (const idx_vector&, int) const { return Array (); } \ diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/Array.h --- a/liboctave/array/Array.h Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/Array.h Tue Jun 17 16:41:11 2014 -0600 @@ -41,6 +41,11 @@ #include "oct-mem.h" #include "oct-refcount.h" +//Forward declaration for array_iterator, +//the nonzero-iterator type for Array (in nz_iterator.h) +template +class array_iterator; + // One dimensional array class. Handles the reference counting for // all the derived classes. @@ -123,6 +128,8 @@ typedef T element_type; + typedef array_iterator iter_type; + typedef typename ref_param::type crefT; typedef bool (*compare_fcn_type) (typename ref_param::type, @@ -611,11 +618,6 @@ // Count nonzero elements. octave_idx_type nnz (void) const; - // Find indices of (at most n) nonzero elements. 
If n is specified, backward - // specifies search from backward. - Array find (octave_idx_type n = -1, - bool backward = false) const; - // Returns the n-th element in increasing order, using the same ordering as // used for sort. n can either be a scalar index or a contiguous range. Array nth_element (const idx_vector& n, int dim = 0) const; diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/DiagArray2.h --- a/liboctave/array/DiagArray2.h Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/DiagArray2.h Tue Jun 17 16:41:11 2014 -0600 @@ -31,6 +31,11 @@ #include "Array.h" +//Forward declaration for diag_iterator, +//the nonzero-iterator type for DiagArray2 (in nz-iterators.h) +template +class diag_iterator; + // Array is inherited privately so that some methods, like index, don't // produce unexpected results. @@ -45,6 +50,8 @@ using typename Array::element_type; + typedef diag_iterator iter_type; + DiagArray2 (void) : Array (), d1 (0), d2 (0) { } diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/PermMatrix.h --- a/liboctave/array/PermMatrix.h Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/PermMatrix.h Tue Jun 17 16:41:11 2014 -0600 @@ -26,6 +26,10 @@ #include "Array.h" #include "mx-defs.h" +//Forward declaration for perm_iterator, +//the nonzero-iterator type for PermMatrix (in nz-iterators.h) +class perm_iterator; + // Array is inherited privately so that some methods, like index, don't // produce unexpected results. 
@@ -33,6 +37,9 @@ { public: + typedef bool element_type; + typedef perm_iterator iter_type; + PermMatrix (void) : Array () { } PermMatrix (octave_idx_type n); @@ -63,6 +70,7 @@ { return perm_length (); } octave_idx_type nelem (void) const { return dim1 () * dim2 (); } octave_idx_type numel (void) const { return nelem (); } + octave_idx_type nnz (void) const { return perm_length (); } size_t byte_size (void) const { return Array::byte_size (); } @@ -104,6 +112,9 @@ bool is_col_perm (void) const { return true; } bool is_row_perm (void) const { return false; } + octave_idx_type perm_elem(octave_idx_type i) const + { return Array::elem (i); } + void print_info (std::ostream& os, const std::string& prefix) const { Array::print_info (os, prefix); } diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/Sparse.h --- a/liboctave/array/Sparse.h Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/Sparse.h Tue Jun 17 16:41:11 2014 -0600 @@ -43,6 +43,11 @@ class idx_vector; class PermMatrix; +//Forward declaration for sparse_iterator, +//the nonzero-iterator type for Sparse (in nz_iterator.h) +template +class sparse_iterator; + // Two dimensional sparse class. Handles the reference counting for // all the derived classes. @@ -54,6 +59,8 @@ typedef T element_type; + typedef sparse_iterator iter_type; + protected: //-------------------------------------------------------------------- // The real representation of all Sparse arrays. diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/dim-vector.h --- a/liboctave/array/dim-vector.h Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/dim-vector.h Tue Jun 17 16:41:11 2014 -0600 @@ -339,7 +339,9 @@ // Return the number of elements that a matrix with this dimension // vector would have, NOT the number of dimensions (elements in the - // dimension vector). + // dimension vector). If given an argument n, returns the number of + // elements an array with dimensions dim[n..end] would have + // (ex. 
numel(1) returns the number of columns) octave_idx_type numel (int n = 0) const { diff -r d00f6b09258f -r 2e0613dadfee liboctave/array/idx-vector.cc --- a/liboctave/array/idx-vector.cc Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/array/idx-vector.cc Tue Jun 17 16:41:11 2014 -0600 @@ -39,6 +39,7 @@ #include "oct-locbuf.h" #include "lo-error.h" #include "lo-mappers.h" +#include "find.h" static void gripe_invalid_range (void) @@ -753,7 +754,7 @@ idx_vector::idx_mask_rep::as_array (void) { if (aowner) - return aowner->find ().reshape (orig_dims); + return find::find (*aowner).reshape (orig_dims); else { Array retval (orig_dims); diff -r d00f6b09258f -r 2e0613dadfee liboctave/util/direction.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/direction.h Tue Jun 17 16:41:11 2014 -0600 @@ -0,0 +1,76 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ +#if !defined (octave_direction_h) +#define octave_direction_h 1 + +// Simple generic parameterized functions for stepping "forward" or "backward" +// FORWARD and BACKWARD are elements of the "direction enum +// +enum direction +{ + FORWARD = 1, BACKWARD = -1 +}; + +// A struct with two overloaded functions: begin(lo, hi) and +// is_ended (i, lo, hi) where i is assumed to be stepping through the range +// [lo, hi) (inclusive, exclusive). 
begin() returns the initial value for i +// and is_ended (i, lo, hi) returns true if i has stepped past the end of the +// range. To increment i in the proper direction, one can simply say i += dir +// (dir is implicitly cast to an int either 1 or -1 for FORWARD and BACKWARD +// respectively). +// When lo = 0, one can instead use the 1- and 2- argument variants of begin +// and is_ended respectively +template +struct dir_handler +{ + octave_idx_type + begin (octave_idx_type size) const + { + return begin (0, size); + } + + octave_idx_type + begin (octave_idx_type lo, octave_idx_type hi) const + { + switch (dir) { + case FORWARD: return lo; + case BACKWARD: return hi - 1; + } + } + + bool + is_ended (octave_idx_type i, octave_idx_type size) const + { + return is_ended (i, 0, size); + } + + bool + is_ended (octave_idx_type i, octave_idx_type lo, octave_idx_type hi) const + { + switch (dir) { + case FORWARD: return i >= hi; + case BACKWARD: return i < lo; + } + } +}; + +#endif diff -r d00f6b09258f -r 2e0613dadfee liboctave/util/find.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/find.h Tue Jun 17 16:41:11 2014 -0600 @@ -0,0 +1,144 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. 
+ +*/ +#if !defined (octave_find_h) +#define octave_find_h 1 + +#include "Array.h" + +#include "interp-idx.h" +#include "min-with-nnz.h" +#include "nz-iterators.h" +#include "direction.h" + +namespace find +{ + // This struct mostly matches the signature of ffind_result (see find.cc) + // except without a get_list function. Internal calls to "find" from + // liboctave should call the "find" method which in turn constructs one + // of these as a value to be filled. + // It should not be used anywhere else + + struct find_result + { + Array res; + template + void + add (octave_idx_type place, const array_iterator& iter) + { + res.xelem (place) = iter.flat_idx (); + } + + find_result (void) { } + find_result (const dim_vector& nnz) : res (nnz) { } + }; + + // This is the generic "find" method for all matrix types. All calls to + // "find" get routed here one way or another as this method contains logic + // for determining the proper output dimensions and handling forward/reversed, + // n_to_find etc. + // It would be unwise to try and duplicate this logic elsewhere as it's fairly + // nuanced and unintuitive (but Matlab compatible). + + template + R + find_to_R (dir_handler dirc, const MT& v, octave_idx_type n_to_find) + { + typedef typename MT::iter_type iter_t; + + octave_idx_type numres; + if (n_to_find == -1) + numres = v.nnz (); + else + numres = min_with_nnz (v, n_to_find); + + const dim_vector& dv = v.dims (); + octave_idx_type col_len = dv.elem (0); + octave_idx_type cols = dv.numel (1); + + // Fixup return dimensions, for Matlab compatibility. + // find (zeros (0,0)) -> zeros (0,0) + // find (zeros (1,0)) -> zeros (1,0) + // find (zeros (0,1)) -> zeros (0,1) + // find (zeros (0,X)) -> zeros (0,1) + // find (zeros (1,1)) -> zeros (0,0) !!!! WHY? 
+ // find (zeros (0,1,0)) -> zeros (0,0) + // find (zeros (0,1,0,1)) -> zeros (0,0) etc + + dim_vector res_dims; + + if ((cols == 1 && col_len == 1 && numres == 0) + || (col_len == 0 && cols == 0)) + res_dims = dim_vector (0, 0); + else if (col_len == 1) + res_dims = dim_vector (1, numres); + else + res_dims = dim_vector (numres, 1); + + iter_t iterator (v); + R resvec (res_dims); + octave_idx_type count = dirc.begin (numres); + + for (iterator.begin (dirc); + !iterator.finished (dirc) && !dirc.is_ended (count, numres); + iterator.step (dirc), count += dir) + { + resvec.add (count, iterator); + } + + return resvec; + } + + // The method call used for internal liboctave calls to find. It constructs + // an find_result and fills that rather than a res_type_of (see "find.cc"). + // Note that this method only supports full arrays because octave_idx_type + // will generally overflow if used to represent the nonzero indices of sparse + // arrays. Instead consider using the sparse-iterator class from + // nz-iterators.h to avoid this problem + + template + Array + find (const Array& v, octave_idx_type n_to_find = -1, direction dir = + FORWARD) + { + find_result resvec; + + switch (dir) + { + case FORWARD: + dir_handler dirc_forward; + resvec = find_to_R (dirc_forward, v, n_to_find); + break; + case BACKWARD: + dir_handler dirc_back; + resvec = find_to_R (dirc_back, v, n_to_find); + break; + default: + liboctave_fatal( + "find: unknown direction: %d; must be FORWARD or BACKWARD", + dir); + } + + return resvec.res; + } +} + +#endif diff -r d00f6b09258f -r 2e0613dadfee liboctave/util/interp-idx.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/interp-idx.h Tue Jun 17 16:41:11 2014 -0600 @@ -0,0 +1,44 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. 
+ +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ +#if !defined (octave_interp_idx_h) +#define octave_interp_idx_h 1 + +// Simple method for converting between C++ octave_idx_type and +// Octave data-types (convert to double and add one). +inline double +to_interp_idx (octave_idx_type idx) +{ + return idx + 1.L; +} + +// Simple method for taking a row-column pair together with the matrix +// height and returning the corresponding index as an octave data-type +// (note that for large heights, there's a risk of losing precision. +// This method will not overflow or throw a bad alloc, it will simply +// choose the nearest possible double-value to the proper index). +inline double +to_interp_idx (octave_idx_type row, octave_idx_type col, octave_idx_type height) +{ + return col * static_cast (height) + row + 1; +} + +#endif diff -r d00f6b09258f -r 2e0613dadfee liboctave/util/min-with-nnz.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/min-with-nnz.h Tue Jun 17 16:41:11 2014 -0600 @@ -0,0 +1,59 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. 
+ +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ +#if !defined (octave_min_with_nnz_h) +#define octave_min_with_nnz_h 1 + +#include "Sparse.h" +#include "direction.h" + +// Generic efficient methods for finding the equivalent of +// min(arr.nnz(), minwith). In case arr.nnz() happens to run +// in O(arr.numel()) time, this operation will short-circuit and return +// as soon as it counts minwith nonzero elements. +// But if arr has a constant-time nnz() operation, then this will use +// that instead and run in constant-time (see the Sparse overload below, which calls arr.nnz () directly). + +template +octave_idx_type +min_with_nnz (M arr, octave_idx_type minwith) +{ + octave_idx_type count; + typename M::iter_type iter (arr); + dir_handler dirc; + for (iter.begin (dirc), count = 0; !iter.finished (dirc) && count < minwith; + iter.step (dirc), ++count); // empty loop body: the header itself steps the iterator and counts until it is finished or count reaches minwith + return count; +} + +template +octave_idx_type +min_with_nnz (Sparse arr, octave_idx_type minwith) +{ + return std::min (arr.nnz (), minwith); +} + +// FIXME: I want to add one for PermMatrix, but for some reason C++ doesn't +// like it when you overload templated functions with non-templated ones. +// Except that PermMatrix is not templated. Anyone know how to handle that? +// Should I add a dummy template parameter? 
(that seems like quite an ugly hack) + +#endif diff -r d00f6b09258f -r 2e0613dadfee liboctave/util/module.mk --- a/liboctave/util/module.mk Mon Aug 11 09:39:45 2014 -0700 +++ b/liboctave/util/module.mk Tue Jun 17 16:41:11 2014 -0600 @@ -9,8 +9,11 @@ util/cmd-edit.h \ util/cmd-hist.h \ util/data-conv.h \ + util/direction.h \ + util/find.h \ util/functor.h \ util/glob-match.h \ + util/interp-idx.h \ util/lo-array-gripes.h \ util/lo-cutils.h \ util/lo-ieee.h \ @@ -18,6 +21,8 @@ util/lo-math.h \ util/lo-traits.h \ util/lo-utils.h \ + util/min-with-nnz.h \ + util/nz-iterators.h \ util/oct-alloc.h \ util/oct-base64.h \ util/oct-binmap.h \ diff -r d00f6b09258f -r 2e0613dadfee liboctave/util/nz-iterators.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/nz-iterators.h Tue Jun 17 16:41:11 2014 -0600 @@ -0,0 +1,403 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ +#if !defined (octave_nz_iterators_h) +#define octave_nz_iterators_h 1 + +#include "interp-idx.h" +#include "oct-inttypes.h" +#include "Array.h" +#include "DiagArray2.h" +#include "PermMatrix.h" +#include "Sparse.h" +#include "direction.h" + +// This file contains generic column-major iterators over +// the nonzero elements of any array or matrix. 
If you have a matrix mat +// of type M, you can construct the proper iterator type using +// M::iter_type iter(mat) and iter will iterate efficiently (forwards +// or backwards) over the nonzero elements of mat. +// +// The parameter T is the element-type except for PermMatrix where +// the element type is always bool +// +// Use a dir_handler to indicate which direction you intend to iterate. +// (see direction.h). begin() resets the iterator to the beginning or +// end of the matrix (for dir_handler<FORWARD> and dir_handler<BACKWARD> respectively). +// finished(dirc) indicates whether the iterator has finished traversing +// the nonzero elements. step(dirc) steps from one element to the next. +// +// You can, for instance, use a for-loop as follows: +// +// typedef M::iter_type iter_t; +// typedef M::element_type T; +// +// iter_t iter(mat); +// dir_handler<1> dirc; +// +// for(iter.begin (dirc); !iter.finished (dirc); iter.step (dirc)) +// { +// octave_idx_type row = iter.row(); +// octave_idx_type col = iter.col(); +// double doub_index = iter.interp_idx (); +// T elem = iter.data(); +// // ... Do something with these +// } +// +// Note that array_iterator for indexing over full matrices also includes +// an iter.flat_idx () method which returns an octave_idx_type. +// +// The other iterators do not have a flat_idx() method because they +// risk overflowing octave_idx_type. It is recommended you take care +// to implement your function in a way that accounts for this problem. +// +// FIXME: I'd like to add in these +// default no-parameter versions of +// begin() and step() to each of the +// classes. But the C++ compiler complains +// because apparently I'm not allowed to overload +// templated methods with non-templated ones. Any +// ideas for work-arounds? 
+// +//#define INCLUDE_DEFAULT_STEPS \ +// void begin (void) \ +// { \ +// dir_handler dirc; \ +// begin (dirc); \ +// } \ +// void step (void) \ +// { \ +// dir_handler dirc; \ +// step (dirc); \ +// } \ +// bool finished (void) const \ +// { \ +// dir_handler dirc; \ +// return finished (dirc); \ +// } + +// A generic method for checking if some element of a matrix with +// element type T is zero. +template +bool +is_zero (T t) +{ + return t == static_cast (0); +} + +// An iterator over full arrays. When the number of dimensions exceeds +// 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1) +// +// This mimics the behavior of the "find" method (both in Octave and Matlab) +// on many-dimensional matrices. + +template +class array_iterator +{ +private: + const Array& mat; + + //Actual total number of columns = mat.dims().numel(1) + //can be different from length of row dimension + const octave_idx_type totcols; + const octave_idx_type numels; + + octave_idx_type coli; + octave_idx_type rowj; + octave_idx_type my_idx; + + template + void + step_once (dir_handler dirc) // moves my_idx and rowj by dir; when rowj runs off its column it is reset and coli moves by dir + { + my_idx += dir; + rowj += dir; + if (dirc.is_ended (rowj, mat.rows ())) + { + rowj = dirc.begin (mat.rows ()); + coli += dir; + } + } + + template + void + move_to_nz (dir_handler dirc) // steps past elements for which is_zero (data ()) holds, stopping once finished + { + while (!finished (dirc) && is_zero (data ())) + { + step_once (dirc); + } + } + + +public: + array_iterator (const Array& arg_mat) + : mat (arg_mat), totcols (arg_mat.dims ().numel (1)), numels ( + totcols * arg_mat.rows ()) + { + dir_handler dirc; + begin (dirc); + } + + template + void + begin (dir_handler dirc) + { + coli = dirc.begin (totcols); + rowj = dirc.begin (mat.rows ()); + my_idx = dirc.begin (mat.numel ()); + move_to_nz (dirc); + } + + octave_idx_type + col (void) const + { + return coli; + } + octave_idx_type + row (void) const + { + return rowj; + } + double + interp_idx (void) const + { + return to_interp_idx (my_idx); + } + octave_idx_type + flat_idx (void) const + { + 
return my_idx; + } + T + data (void) const + { + return mat.elem (my_idx); + } + + template + void + step (dir_handler dirc) + { + step_once (dirc); + move_to_nz (dirc); + } + template + bool + finished (dir_handler dirc) const + { + return dirc.is_ended (my_idx, numels); + } +}; + +template +class sparse_iterator // walks the stored entries of a Sparse matrix by position my_idx, keeping coli in step through mat.cidx +{ +private: + const Sparse& mat; + octave_idx_type coli; + octave_idx_type my_idx; + + template + void + adjust_col (dir_handler dirc) + { + while (!finished (dirc) + && dirc.is_ended (my_idx, mat.cidx (coli), mat.cidx (coli + 1))) + coli += dir; + } + +public: + sparse_iterator (const Sparse& arg_mat) : + mat (arg_mat) + { + dir_handler dirc; + begin (dirc); + } + + template + void + begin (dir_handler dirc) + { + coli = dirc.begin (mat.cols ()); + my_idx = dirc.begin (mat.nnz ()); + adjust_col (dirc); + } + + double + interp_idx (void) const + { + return to_interp_idx (row (), col (), mat.rows ()); + } + octave_idx_type + col (void) const + { + return coli; + } + octave_idx_type + row (void) const + { + return mat.ridx (my_idx); + } + T + data (void) const + { + return mat.data (my_idx); + } + template + void + step (dir_handler dirc) + { + my_idx += dir; + adjust_col (dirc); + } + template + bool + finished (dir_handler dirc) const + { + return dirc.is_ended (coli, mat.cols ()); + } +}; + +template +class diag_iterator // walks the nonzero diagonal entries of a DiagArray2; row () and col () are both my_idx +{ +private: + const DiagArray2& mat; + octave_idx_type my_idx; + + template + void + move_to_nz (dir_handler dirc) + { + while (!finished (dirc) && is_zero (data ())) + { + my_idx += dir; + } + } + +public: + diag_iterator (const DiagArray2& arg_mat) : + mat (arg_mat) + { + dir_handler dirc; + begin (dirc); + } + + template + void + begin (dir_handler dirc) + { + my_idx = dirc.begin (mat.diag_length ()); + move_to_nz (dirc); + } + + double + interp_idx (void) const + { + return to_interp_idx (row (), col (), mat.rows ()); + } + octave_idx_type + col (void) const + { + return my_idx; + } + octave_idx_type + row (void) const + { + return 
my_idx; + } + T + data (void) const + { + return mat.dgelem (my_idx); + } + template + void + step (dir_handler dirc) + { + my_idx += dir; + move_to_nz (dirc); + } + template + bool + finished (dir_handler dirc) const + { + return dirc.is_ended (my_idx, mat.diag_length ()); + } +}; + +class perm_iterator // walks the columns of a PermMatrix; row () is mat.perm_elem (my_idx) and data () is always true +{ +private: + const PermMatrix& mat; + octave_idx_type my_idx; + +public: + perm_iterator (const PermMatrix& arg_mat) : + mat (arg_mat) + { + dir_handler dirc; + begin (dirc); + } + + template + void + begin (dir_handler dirc) + { + my_idx = dirc.begin (mat.cols ()); + } + + double // double, as in the other iterators: keeps the value to_interp_idx computed instead of converting it to octave_idx_type + interp_idx (void) const + { + return to_interp_idx (row (), col (), mat.rows ()); + } + octave_idx_type + col (void) const + { + return my_idx; + } + octave_idx_type + row (void) const + { + return mat.perm_elem (my_idx); + } + bool + data (void) const + { + return true; + } + template + void + step (dir_handler) + { + my_idx += dir; + } + template + bool + finished (dir_handler dirc) const + { + return dirc.is_ended (my_idx, mat.rows ()); + } +}; + +#endif diff -r d00f6b09258f -r 2e0613dadfee test/build-sparse-tests.sh --- a/test/build-sparse-tests.sh Mon Aug 11 09:39:45 2014 -0700 +++ b/test/build-sparse-tests.sh Tue Jun 17 16:41:11 2014 -0600 @@ -569,7 +569,8 @@ %!assert (as==af) %!assert (af==as) %!test -%! [ii,jj,vv,nr,nc] = find (as); +%! [ii,jj,vv] = find (as); +%! [nr,nc] = size (as); %! assert (af, full (sparse (ii,jj,vv,nr,nc))); %!assert (nnz (as), sum (af(:)!=0)) %!assert (nnz (as), nnz (af)) @@ -598,7 +599,8 @@ %! x = sparse (i,j,v,m,n); %! assert (x, as); %!test -%! [i,j,v,m,n] = find (as); +%! [i,j,v] = find (as); +%! [m,n] = size (as); %! x = sparse (i,j,v,m,n); %! assert (x, as); %!assert (issparse (horzcat (as,as)));