# HG changeset patch # User Jaroslav Hajek # Date 1238071846 -3600 # Node ID e67dc11ed6e8d0e90f43c36ff8825033e73ddbd4 # Parent 9a46ba093db49db28a2ae4a621680733a6627430 use Array::find in find diff -r 9a46ba093db4 -r e67dc11ed6e8 src/ChangeLog --- a/src/ChangeLog Thu Mar 26 13:20:05 2009 +0100 +++ b/src/ChangeLog Thu Mar 26 13:50:46 2009 +0100 @@ -1,3 +1,12 @@ +2009-03-26 Jaroslav Hajek + + * DLD-FUNCTIONS/find.cc + (find_nonzero_elem_idx (const Array&, ...)): Simplify. + Instantiate for bool and octave_int types. + (find_nonzero_elem_idx (const Sparse&, ...)): + Instantiate for bool. + (Ffind): Handle bool and octave_int cases. + 2009-03-25 John W. Eaton * version.h (OCTAVE_VERSION): Now 3.1.55+. diff -r 9a46ba093db4 -r e67dc11ed6e8 src/DLD-FUNCTIONS/find.cc --- a/src/DLD-FUNCTIONS/find.cc Thu Mar 26 13:20:05 2009 +0100 +++ b/src/DLD-FUNCTIONS/find.cc Thu Mar 26 13:50:46 2009 +0100 @@ -43,155 +43,75 @@ { octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); - octave_idx_type count = 0; - - octave_idx_type nel = nda.nelem (); - - // Set the starting element to the correct value based on the - // direction to search. - octave_idx_type k = 0; - if (direction == -1) - k = nel - 1; - - // Search in the default range. - octave_idx_type start_el = -1; - octave_idx_type end_el = -1; - - // Search for the number of elements to return. - while (k < nel && k > -1 && n_to_find != count) - { - OCTAVE_QUIT; - - if (nda(k) != T ()) - { - end_el = k; - if (start_el == -1) - start_el = k; - count++; - } - k = k + direction; - } - - // Reverse the range if we're looking backward. - if (direction == -1) - { - octave_idx_type tmp_el = start_el; - start_el = end_el; - end_el = tmp_el; - } - // Fix an off by one error. - end_el++; - - // If the original argument was a row vector, force a row vector of - // the overall indices to be returned. But see below for scalar - // case... - - octave_idx_type result_nr = count; - octave_idx_type result_nc = 1; - - bool column_vector_arg = false; - bool scalar_arg = false; - - if (nda.ndims () == 2) - { - octave_idx_type nr = nda.rows (); - octave_idx_type nc = nda.columns (); + Array idx; + if (n_to_find >= 0) + idx = nda.find (n_to_find, direction == -1); + else + idx = nda.find (); - if (nr == 1) - { - result_nr = 1; - result_nc = count; - - scalar_arg = (nc == 1); - } - else if (nc == 1) - column_vector_arg = true; - } - - Matrix idx (result_nr, result_nc); - - Matrix i_idx (result_nr, result_nc); - Matrix j_idx (result_nr, result_nc); - - ArrayN val (dim_vector (result_nr, result_nc)); - - if (count > 0) - { - count = 0; - - octave_idx_type nr = nda.rows (); - - octave_idx_type i = 0; - - // Search for elements to return. Only search the region where - // there are elements to be found using the count that we want - // to find. + // Fixup idx dimensions, for Matlab compatibility. + // find(zeros(0,0)) -> zeros(0,0) + // find(zeros(1,0)) -> zeros(1,0) + // find(zeros(0,1)) -> zeros(0,1) + // find(zeros(0,X)) -> zeros(0,1) + // find(zeros(1,1)) -> zeros(0,0) !!!! WHY? + // find(zeros(0,1,0)) -> zeros(0,0) + // find(zeros(0,1,0,1)) -> zeros(0,0) etc + // FIXME: I don't believe this is right. Matlab seems to violate its own docs + // here, because a scalar *is* a row vector. - // For compatibility, all N-d arrays are handled as if they are - // 2-d, with the number of columns equal to "prod (dims (2:end))". - - for (k = start_el; k < end_el; k++) - { - OCTAVE_QUIT; - - if (nda(k) != T ()) - { - idx(count) = k + 1; - - octave_idx_type xr = k % nr; - i_idx(count) = xr + 1; - j_idx(count) = (k - xr) / nr + 1; - - val(count) = nda(k); - - count++; - } - - i++; - } - } - else if (scalar_arg || (nda.rows () == 0 && ! column_vector_arg)) - { - idx.resize (0, 0); - - i_idx.resize (0, 0); - j_idx.resize (0, 0); - - val.resize (dim_vector (0, 0)); - } + if ((nda.numel () == 1 && idx.is_empty ()) + || (nda.rows () == 0 && nda.dims ().numel (1) == 0)) + idx = idx.reshape (dim_vector (0, 0)); + else if (nda.rows () == 1 && nda.ndims () == 2) + idx = idx.reshape (dim_vector (1, idx.length ())); switch (nargout) { default: case 3: - retval(2) = val; + retval(2) = ArrayN (nda.index (idx_vector (idx))); // Fall through! case 2: - retval(1) = j_idx; - retval(0) = i_idx; - break; + { + Array jdx (idx.dims ()); + octave_idx_type n = idx.length (), nr = nda.rows (); + for (octave_idx_type i = 0; i < n; i++) + { + jdx.xelem (i) = idx.xelem (i) / nr; + idx.xelem (i) %= nr; + } + retval(1) = NDArray (jdx, true); + } + // Fall through! case 1: case 0: - retval(0) = idx; + retval(0) = NDArray (idx, true); break; } return retval; } -template octave_value_list find_nonzero_elem_idx (const Array&, int, - octave_idx_type, int); - -template octave_value_list find_nonzero_elem_idx (const Array&, int, - octave_idx_type, int); +#define INSTANTIATE_FIND_ARRAY(T) \ +template octave_value_list find_nonzero_elem_idx (const Array&, int, \ + octave_idx_type, int) -template octave_value_list find_nonzero_elem_idx (const Array&, int, - octave_idx_type, int); - -template octave_value_list find_nonzero_elem_idx (const Array&, - int, octave_idx_type, int); +INSTANTIATE_FIND_ARRAY(double); +INSTANTIATE_FIND_ARRAY(float); +INSTANTIATE_FIND_ARRAY(Complex); +INSTANTIATE_FIND_ARRAY(FloatComplex); +INSTANTIATE_FIND_ARRAY(bool); +INSTANTIATE_FIND_ARRAY(octave_int8); +INSTANTIATE_FIND_ARRAY(octave_int16); +INSTANTIATE_FIND_ARRAY(octave_int32); +INSTANTIATE_FIND_ARRAY(octave_int64); +INSTANTIATE_FIND_ARRAY(octave_uint8); +INSTANTIATE_FIND_ARRAY(octave_uint16); +INSTANTIATE_FIND_ARRAY(octave_uint32); +INSTANTIATE_FIND_ARRAY(octave_uint64); template octave_value_list @@ -342,6 +262,9 @@ template octave_value_list find_nonzero_elem_idx (const Sparse&, int, octave_idx_type, int); +template octave_value_list find_nonzero_elem_idx (const Sparse&, int, + octave_idx_type, int); + octave_value_list find_nonzero_elem_idx (const PermMatrix& v, int nargout, octave_idx_type n_to_find, int direction) @@ -561,7 +484,51 @@ octave_value arg = args(0); - if (arg.is_sparse_type ()) + if (arg.is_bool_type ()) + { + if (arg.is_sparse_type ()) + { + SparseBoolMatrix v = arg.sparse_bool_matrix_value (); + + if (! error_state) + retval = find_nonzero_elem_idx (v, nargout, + n_to_find, direction); + } + else + { + boolNDArray v = arg.bool_array_value (); + + if (! error_state) + retval = find_nonzero_elem_idx (v, nargout, + n_to_find, direction); + } + } + else if (arg.is_integer_type ()) + { +#define DO_INT_BRANCH(INTT) \ + else if (arg.is_ ## INTT ## _type ()) \ + { \ + INTT ## NDArray v = arg.INTT ## _array_value (); \ + \ + if (! error_state) \ + retval = find_nonzero_elem_idx (v, nargout, \ + n_to_find, direction);\ + } + + if (false) + ; + DO_INT_BRANCH (int8) + DO_INT_BRANCH (int16) + DO_INT_BRANCH (int32) + DO_INT_BRANCH (int64) + DO_INT_BRANCH (uint8) + DO_INT_BRANCH (uint16) + DO_INT_BRANCH (uint32) + DO_INT_BRANCH (uint64) + else + panic_impossible (); + } + else if (arg.is_sparse_type ()) { if (arg.is_real_type ()) {