diff libinterp/corefcn/find.cc @ 19006:2e0613dadfee draft

All calls to "find" use the same generic implementation (bug #42408, 42421) * find.cc: Rewrite. Move generic "find" logic to find.h (Ffind) : Changed calls to find_nonzero_elem_idx to find_templated Added unit test for bug #42421 * Array.cc (and .h) (Array::find): Deleted function. Replaced with find::find(Array) from find.h * Array.h: Added typedef for array_iterator (in nz-iterators.h) as Array::iter_type * DiagArray2.h: Added typedef for diag_iterator (in nz-iterators.h) as DiagArray2::iter_type * PermMatrix.h: Added typedef for perm_iterator (in nz-iterators.h) as PermMatrix::iter_type Also added typedef for bool as PermMatrix::element_type (not octave_idx_type) Added an nnz() function (which is an alias for perm_length) and a perm_elem(i) function for retrieving the ith element of the permutation * Sparse.h: Added typedef for sparse_iterator (in nz-iterators.h) as Sparse::iter_type Added a short comment documenting the argument to the numel function * idx-vector.cc (idx_vector::idx_mask_rep::as_array): Changed Array.find to find::find(Array) (in find.h) * (new file) find.h * (new file) interp-idx.h: Simple methods for converting between interpreter index type and internal octave_idx_type/row-col pair * (new file) min-with-nnz.h: Fast methods for taking an arbitrary matrix M and an octave_idx_type n and finding min(M.nnz(), n) * (new file) nz-iterators.h: Iterators for traversing (in column-major order) the nonzero elements of any array or matrix backwards or forwards * (new file) direction.h: Generic methods for simplifying code that has to deal with a "backwards or forwards" template argument * build-sparse-tests.sh: Removed 5-return-value calls to "find" in unit-tests; Admittedly this commit breaks this "feature" which was undocumented and only partially supported to begin with (i.e. never worked for full matrices, permutation matrices, or diagonal matrices)
author David Spies <dnspies@gmail.com>
date Tue, 17 Jun 2014 16:41:11 -0600
parents aa9ca67f09fb
children 80ca3b05d77c
line wrap: on
line diff
--- a/libinterp/corefcn/find.cc	Mon Aug 11 09:39:45 2014 -0700
+++ b/libinterp/corefcn/find.cc	Tue Jun 17 16:41:11 2014 -0600
@@ -1,6 +1,7 @@
 /*
 
 Copyright (C) 1996-2013 John W. Eaton
+Copyright (C) 2014 David Spies
 
 This file is part of Octave.
 
@@ -24,305 +25,170 @@
 #include <config.h>
 #endif
 
-#include "quit.h"
+#include "find.h"
 
 #include "defun.h"
 #include "error.h"
 #include "gripes.h"
 #include "oct-obj.h"
 
-// Find at most N_TO_FIND nonzero elements in NDA.  Search forward if
-// DIRECTION is 1, backward if it is -1.  NARGOUT is the number of
-// output arguments.  If N_TO_FIND is -1, find all nonzero elements.
-
-template <typename T>
-octave_value_list
-find_nonzero_elem_idx (const Array<T>& nda, int nargout,
-                       octave_idx_type n_to_find, int direction)
+namespace find
 {
-  octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
+  // ffind_result is a generic type used for storing the result of
+  // a find operation.  The way in which this result is stored will
+  // vary based on whether the number of requested return values
+  // is 1, 2, or 3.
+  // Each instantiation of ffind_result must support a few
+  // operations.  It is constructed with a dim_vector which indicates
+  // the size of the return vectors (the size of all return values for
+  // find is the same).  It supports an add() operation which
+  // the generic "find" method will call to add an element to the return
+  // values.  "add" takes an index and an nz-iterator for the matrix type
+  // being searched (from nz-iterators.h).
+  // Finally, it supports get_list, which returns an octave_value_list of the
+  // return values (1, 2, or 3 depending on the nargout template argument).
 
-  Array<octave_idx_type> idx;
-  if (n_to_find >= 0)
-    idx = nda.find (n_to_find, direction == -1);
-  else
-    idx = nda.find ();
+  template<int nargout, typename T>
+  struct ffind_result;
 
-  // The maximum element is always at the end.
-  octave_idx_type iext = idx.is_empty () ? 0 : idx.xelem (idx.numel () - 1) + 1;
+  template<typename T>
+  struct ffind_result<1, T>
+  {
+    ffind_result (void) { }
+    ffind_result (const dim_vector& nnz) : res (nnz) { }
 
-  switch (nargout)
+    Array<double> res;
+
+    template<typename iter_t>
+    void
+    add (octave_idx_type place, const iter_t& iter)
     {
-    default:
-    case 3:
-      retval(2) = Array<T> (nda.index (idx_vector (idx)));
-      // Fall through!
-
-    case 2:
-      {
-        Array<octave_idx_type> jdx (idx.dims ());
-        octave_idx_type n = idx.length ();
-        octave_idx_type nr = nda.rows ();
-        for (octave_idx_type i = 0; i < n; i++)
-          {
-            jdx.xelem (i) = idx.xelem (i) / nr;
-            idx.xelem (i) %= nr;
-          }
-        iext = -1;
-        retval(1) = idx_vector (jdx, -1);
-      }
-      // Fall through!
-
-    case 1:
-    case 0:
-      retval(0) = idx_vector (idx, iext);
-      break;
+      res.xelem (place) = iter.interp_idx ();
     }
 
-  return retval;
-}
-
-template <typename T>
-octave_value_list
-find_nonzero_elem_idx (const Sparse<T>& v, int nargout,
-                       octave_idx_type n_to_find, int direction)
-{
-  octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
-
-  octave_idx_type nr = v.rows ();
-  octave_idx_type nc = v.cols ();
-  octave_idx_type nz = v.nnz ();
-
-  // Search in the default range.
-  octave_idx_type start_nc = -1;
-  octave_idx_type end_nc = -1;
-  octave_idx_type count;
-
-  // Search for the range to search
-  if (n_to_find < 0)
-    {
-      start_nc = 0;
-      end_nc = nc;
-      n_to_find = nz;
-      count = nz;
-    }
-  else if (direction > 0)
+    octave_value_list
+    get_list (void)
     {
-      for (octave_idx_type j = 0; j < nc; j++)
-        {
-          OCTAVE_QUIT;
-          if (v.cidx (j) == 0 && v.cidx (j+1) != 0)
-            start_nc = j;
-          if (v.cidx (j+1) >= n_to_find)
-            {
-              end_nc = j + 1;
-              break;
-            }
-        }
+      return octave_value_list (octave_value (res));
     }
-  else
+  };
+
+  template<typename T>
+  struct ffind_result<2, T>
+  {
+    ffind_result (void) { }
+    ffind_result (const dim_vector& nnz) : rescol (nnz), resrow (nnz) { }
+
+    Array<double> rescol;
+    Array<double> resrow;
+
+    template<typename iter_t>
+    void
+    add (octave_idx_type place, const iter_t& iter)
     {
-      for (octave_idx_type j = nc; j > 0; j--)
-        {
-          OCTAVE_QUIT;
-          if (v.cidx (j) == nz && v.cidx (j-1) != nz)
-            end_nc = j;
-          if (nz - v.cidx (j-1) >= n_to_find)
-            {
-              start_nc = j - 1;
-              break;
-            }
-        }
-    }
-
-  count = (n_to_find > v.cidx (end_nc) - v.cidx (start_nc) ?
-           v.cidx (end_nc) - v.cidx (start_nc) : n_to_find);
-
-  octave_idx_type result_nr;
-  octave_idx_type result_nc;
-
-  // Default case is to return a column vector, however, if the original
-  // argument was a row vector, then force return of a row vector.
-  if (nr == 1)
-    {
-      result_nr = 1;
-      result_nc = count;
-    }
-  else
-    {
-      result_nr = count;
-      result_nc = 1;
+      rescol.xelem (place) = to_interp_idx (iter.col ());
+      resrow.xelem (place) = to_interp_idx (iter.row ());
     }
 
-  Matrix idx (result_nr, result_nc);
-
-  Matrix i_idx (result_nr, result_nc);
-  Matrix j_idx (result_nr, result_nc);
-
-  Array<T> val (dim_vector (result_nr, result_nc));
-
-  if (count > 0)
+    octave_value_list
+    get_list (void)
     {
-      // Search for elements to return.  Only search the region where there
-      // are elements to be found using the count that we want to find.
-      for (octave_idx_type j = start_nc, cx = 0; j < end_nc; j++)
-        for (octave_idx_type i = v.cidx (j); i < v.cidx (j+1); i++)
-          {
-            OCTAVE_QUIT;
-            if (direction < 0 && i < nz - count)
-              continue;
-            i_idx(cx) = static_cast<double> (v.ridx (i) + 1);
-            j_idx(cx) = static_cast<double> (j + 1);
-            idx(cx) = j * nr + v.ridx (i) + 1;
-            val(cx) = v.data(i);
-            cx++;
-            if (cx == count)
-              break;
-          }
+      octave_value_list res (2);
+      res.xelem (0) = resrow;
+      res.xelem (1) = rescol;
+      return res;
     }
-  else
-    {
-      // No items found.  Fixup return dimensions for Matlab compatibility.
-      // The behavior to match is documented in Array.cc (Array<T>::find).
-      if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1))
-        {
-          idx.resize (0, 0);
+  };
 
-          i_idx.resize (0, 0);
-          j_idx.resize (0, 0);
-
-          val.resize (dim_vector (0, 0));
-        }
-    }
-
-  switch (nargout)
-    {
-    case 0:
-    case 1:
-      retval(0) = idx;
-      break;
+  template<typename T>
+  struct ffind_result<3, T>
+  {
+    ffind_result (void) { }
+    ffind_result (const dim_vector& nnz) :
+      rescol (nnz), resrow (nnz), elems (nnz) { }
+    Array<double> rescol;
+    Array<double> resrow;
+    Array<T> elems;
 
-    case 5:
-      retval(4) = nc;
-      // Fall through
-
-    case 4:
-      retval(3) = nr;
-      // Fall through
-
-    case 3:
-      retval(2) = val;
-      // Fall through!
-
-    case 2:
-      retval(1) = j_idx;
-      retval(0) = i_idx;
-      break;
-
-    default:
-      panic_impossible ();
-      break;
+    template<typename iter_t>
+    void
+    add (octave_idx_type place, const iter_t& iter)
+    {
+      rescol.xelem (place) = to_interp_idx (iter.col ());
+      resrow.xelem (place) = to_interp_idx (iter.row ());
+      elems.xelem (place) = iter.data ();
     }
 
-  return retval;
-}
-
-octave_value_list
-find_nonzero_elem_idx (const PermMatrix& v, int nargout,
-                       octave_idx_type n_to_find, int direction)
-{
-  // There are far fewer special cases to handle for a PermMatrix.
-  octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
-
-  octave_idx_type nr = v.rows ();
-  octave_idx_type nc = v.cols ();
-  octave_idx_type start_nc, count;
-
-  // Determine the range to search.
-  if (n_to_find < 0 || n_to_find >= nc)
-    {
-      start_nc = 0;
-      count = nc;
-    }
-  else if (direction > 0)
+    octave_value_list
+    get_list (void)
     {
-      start_nc = 0;
-      count = n_to_find;
-    }
-  else
-    {
-      start_nc = nc - n_to_find;
-      count = n_to_find;
-    }
-
-  Matrix idx (count, 1);
-  Matrix i_idx (count, 1);
-  Matrix j_idx (count, 1);
-  // Every value is 1.
-  Array<double> val (dim_vector (count, 1), 1.0);
-
-  if (count > 0)
-    {
-      const Array<octave_idx_type>& p = v.col_perm_vec ();
-      for (octave_idx_type k = 0; k < count; k++)
-        {
-          OCTAVE_QUIT;
-          const octave_idx_type j = start_nc + k;
-          const octave_idx_type i = p(j);
-          i_idx(k) = static_cast<double> (1+i);
-          j_idx(k) = static_cast<double> (1+j);
-          idx(k) = j * nc + i + 1;
-        }
+      octave_value_list res (3);
+      res.xelem (0) = resrow;
+      res.xelem (1) = rescol;
+      res.xelem (2) = elems;
+      return res;
     }
-  else
-    {
-      // FIXME: Is this case even possible?  A scalar permutation matrix seems
-      // to devolve to a scalar full matrix, at least from the Octave command
-      // line.  Perhaps this function could be called internally from C++ with
-      // such a matrix.
-      // No items found.  Fixup return dimensions for Matlab compatibility.
-      // The behavior to match is documented in Array.cc (Array<T>::find).
-      if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1))
-        {
-          idx.resize (0, 0);
+  };
 
-          i_idx.resize (0, 0);
-          j_idx.resize (0, 0);
+  // Calls the find_to_R template method in "find.h", which
+  // in turn will fill in resvec of type ffind_result (see above).
+  // This is generic enough to work for any matrix type M.
+  template<direction dir, int nargout, typename M>
+  octave_value_list
+  call (const M& v, octave_idx_type n_to_find)
+  {
+    ffind_result<nargout, typename M::element_type> resvec;
+    dir_handler<dir> dirc;
+    resvec =
+            find_to_R<ffind_result<nargout, typename M::element_type> > (dirc,
+                    v, n_to_find);
 
-          val.resize (dim_vector (0, 0));
-        }
-    }
+    return resvec.get_list ();
+  }
 
-  switch (nargout)
-    {
-    case 0:
-    case 1:
-      retval(0) = idx;
-      break;
-
-    case 5:
-      retval(4) = nc;
-      // Fall through
-
-    case 4:
-      retval(3) = nc;
-      // Fall through
+  template<int nargout, typename M>
+  octave_value_list
+  dir_to_template (const M& v, octave_idx_type n_to_find, direction dir)
+  {
+    switch (dir)
+      {
+      case BACKWARD:
+        return call<BACKWARD, nargout> (v, n_to_find);
+      case FORWARD:
+        return call<FORWARD, nargout> (v, n_to_find);
+      default:
+        panic_impossible ();
+      }
+    return octave_value_list ();
+  }
 
-    case 3:
-      retval(2) = val;
-      // Fall through!
+  template<typename M>
+  octave_value_list
+  nargout_to_template (const M& v, int nargout, octave_idx_type n_to_find,
+                       direction dir)
+  {
+    switch (nargout)
+      {
+      case 1:
+        return dir_to_template<1> (v, n_to_find, dir);
+      case 2:
+        return dir_to_template<2> (v, n_to_find, dir);
+      case 3:
+        return dir_to_template<3> (v, n_to_find, dir);
+      default:
+        panic_impossible (); // Checked by *** in Ffind
+      }
+    return octave_value_list ();
+  }
 
-    case 2:
-      retval(1) = j_idx;
-      retval(0) = i_idx;
-      break;
+  template<typename M>
+  octave_value_list
+  find_templated (const M& v, int nargout, octave_idx_type n_to_find,
+                  direction dir)
+  {
+    return nargout_to_template (v, nargout, n_to_find, dir);
+  }
 
-    default:
-      panic_impossible ();
-      break;
-    }
-
-  return retval;
 }
 
 DEFUN (find, args, nargout,
@@ -399,6 +265,12 @@
       return retval;
     }
 
+  // *** Clamp nargout to the range [1, 3].
+  if (nargout < 1)
+    nargout = 1;
+  else if (nargout > 3)
+    nargout = 3;
+
   // Setup the default options.
   octave_idx_type n_to_find = -1;
   if (nargin > 1)
@@ -414,23 +286,22 @@
         n_to_find = val;
     }
 
-  // Direction to do the searching (1 == forward, -1 == reverse).
-  int direction = 1;
+  // Direction to do the searching.
+  direction dir = FORWARD;
   if (nargin > 2)
     {
-      direction = 0;
-
       std::string s_arg = args(2).string_value ();
 
-      if (! error_state)
+      if (error_state)
         {
-          if (s_arg == "first")
-            direction = 1;
-          else if (s_arg == "last")
-            direction = -1;
+          error ("find: DIRECTION must be \"first\" or \"last\"");
+          return retval;
         }
-
-      if (direction == 0)
+      if (s_arg == "first")
+        dir = FORWARD;
+      else if (s_arg == "last")
+        dir = BACKWARD;
+      else
         {
           error ("find: DIRECTION must be \"first\" or \"last\"");
           return retval;
@@ -446,10 +317,9 @@
           SparseBoolMatrix v = arg.sparse_bool_matrix_value ();
 
           if (! error_state)
-            retval = find_nonzero_elem_idx (v, nargout,
-                                            n_to_find, direction);
+            retval = find::find_templated (v, nargout, n_to_find, dir);
         }
-      else if (nargout <= 1 && n_to_find == -1 && direction == 1)
+      else if (nargout <= 1 && n_to_find == -1)
         {
           // This case is equivalent to extracting indices from a logical
           // matrix. Try to reuse the possibly cached index vector.
@@ -460,20 +330,18 @@
           boolNDArray v = arg.bool_array_value ();
 
           if (! error_state)
-            retval = find_nonzero_elem_idx (v, nargout,
-                                            n_to_find, direction);
+            retval = find::find_templated (v, nargout, n_to_find, dir);
         }
     }
   else if (arg.is_integer_type ())
     {
-#define DO_INT_BRANCH(INTT) \
-      else if (arg.is_ ## INTT ## _type ()) \
-        { \
-          INTT ## NDArray v = arg.INTT ## _array_value (); \
-          \
-          if (! error_state) \
-            retval = find_nonzero_elem_idx (v, nargout, \
-                                            n_to_find, direction);\
+#define DO_INT_BRANCH(INTT)                                               \
+      else if (arg.is_ ## INTT ## _type ())                               \
+        {                                                                 \
+          INTT ## NDArray v = arg.INTT ## _array_value ();                \
+                                                                          \
+            if (! error_state)                                            \
+              retval = find::find_templated (v, nargout, n_to_find, dir); \
         }
 
       if (false)
@@ -496,16 +364,14 @@
           SparseMatrix v = arg.sparse_matrix_value ();
 
           if (! error_state)
-            retval = find_nonzero_elem_idx (v, nargout,
-                                            n_to_find, direction);
+            retval = find::find_templated (v, nargout, n_to_find, dir);
         }
       else if (arg.is_complex_type ())
         {
           SparseComplexMatrix v = arg.sparse_complex_matrix_value ();
 
           if (! error_state)
-            retval = find_nonzero_elem_idx (v, nargout,
-                                            n_to_find, direction);
+            retval = find::find_templated (v, nargout, n_to_find, dir);
         }
       else
         gripe_wrong_type_arg ("find", arg);
@@ -515,14 +381,14 @@
       PermMatrix P = arg.perm_matrix_value ();
 
       if (! error_state)
-        retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction);
+        retval = find::find_templated (P, nargout, n_to_find, dir);
     }
   else if (arg.is_string ())
     {
       charNDArray chnda = arg.char_array_value ();
 
       if (! error_state)
-        retval = find_nonzero_elem_idx (chnda, nargout, n_to_find, direction);
+        retval = find::find_templated (chnda, nargout, n_to_find, dir);
     }
   else if (arg.is_single_type ())
     {
@@ -531,16 +397,14 @@
           FloatNDArray nda = arg.float_array_value ();
 
           if (! error_state)
-            retval = find_nonzero_elem_idx (nda, nargout, n_to_find,
-                                            direction);
+            retval = find::find_templated (nda, nargout, n_to_find, dir);
         }
       else if (arg.is_complex_type ())
         {
           FloatComplexNDArray cnda = arg.float_complex_array_value ();
 
           if (! error_state)
-            retval = find_nonzero_elem_idx (cnda, nargout, n_to_find,
-                                            direction);
+            retval = find::find_templated (cnda, nargout, n_to_find, dir);
         }
     }
   else if (arg.is_real_type ())
@@ -548,14 +412,14 @@
       NDArray nda = arg.array_value ();
 
       if (! error_state)
-        retval = find_nonzero_elem_idx (nda, nargout, n_to_find, direction);
+        retval = find::find_templated (nda, nargout, n_to_find, dir);
     }
   else if (arg.is_complex_type ())
     {
       ComplexNDArray cnda = arg.complex_array_value ();
 
       if (! error_state)
-        retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, direction);
+        retval = find::find_templated (cnda, nargout, n_to_find, dir);
     }
   else
     gripe_wrong_type_arg ("find", arg);
@@ -611,5 +475,11 @@
 %!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5])
 %!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5])
 
+%!test
+%! x = sparse(100000, 30000);
+%! x(end, end) = 1;
+%! i = find(x);
+%! assert (i == 3e09);
+
 %!error find ()
 */