changeset 9028:e67dc11ed6e8

use Array<T>::find in find
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 26 Mar 2009 13:50:46 +0100
parents 9a46ba093db4
children 2df28ad88b0e
files src/ChangeLog src/DLD-FUNCTIONS/find.cc
diffstat 2 files changed, 106 insertions(+), 130 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Thu Mar 26 13:20:05 2009 +0100
+++ b/src/ChangeLog	Thu Mar 26 13:50:46 2009 +0100
@@ -1,3 +1,12 @@
+2009-03-26  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/find.cc 
+	(find_nonzero_elem_idx (const Array<T>&, ...)): Simplify.
+	Instantiate for bool and octave_int types.
+	(find_nonzero_elem_idx (const Sparse<T>&, ...)): 
+	Instantiate for bool.
+	(Ffind): Handle bool and octave_int cases.
+
 2009-03-25  John W. Eaton  <jwe@octave.org>
 
 	* version.h (OCTAVE_VERSION): Now 3.1.55+.
--- a/src/DLD-FUNCTIONS/find.cc	Thu Mar 26 13:20:05 2009 +0100
+++ b/src/DLD-FUNCTIONS/find.cc	Thu Mar 26 13:50:46 2009 +0100
@@ -43,155 +43,75 @@
 {
   octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
 
-  octave_idx_type count = 0;
-
-  octave_idx_type nel = nda.nelem ();
-
-  // Set the starting element to the correct value based on the
-  // direction to search.
-  octave_idx_type k = 0;
-  if (direction == -1)
-    k = nel - 1;
-
-  // Search in the default range.
-  octave_idx_type start_el = -1;
-  octave_idx_type end_el = -1;
-
-  // Search for the number of elements to return.
-  while (k < nel && k > -1 && n_to_find != count)
-    {
-      OCTAVE_QUIT;
-
-      if (nda(k) != T ())
-	{
-	  end_el = k;
-	  if (start_el == -1)
-	    start_el = k;
-	  count++;
-	}
-      k = k + direction;
-    }
-
-  // Reverse the range if we're looking backward.
-  if (direction == -1)
-    {
-      octave_idx_type tmp_el = start_el;
-      start_el = end_el;
-      end_el = tmp_el;
-    }
-  // Fix an off by one error.
-  end_el++;
-
-  // If the original argument was a row vector, force a row vector of
-  // the overall indices to be returned.  But see below for scalar
-  // case...
-
-  octave_idx_type result_nr = count;
-  octave_idx_type result_nc = 1;
-
-  bool column_vector_arg = false;
-  bool scalar_arg = false;
-
-  if (nda.ndims () == 2)
-    {
-      octave_idx_type nr = nda.rows ();
-      octave_idx_type nc = nda.columns ();
+  Array<octave_idx_type> idx;
+  if (n_to_find >= 0)
+    idx = nda.find (n_to_find, direction == -1);
+  else
+    idx = nda.find ();
 
-      if (nr == 1)
-	{
-	  result_nr = 1;
-	  result_nc = count;
-
-	  scalar_arg = (nc == 1);
-	}
-      else if (nc == 1)
-	column_vector_arg = true;
-    }
-
-  Matrix idx (result_nr, result_nc);
-
-  Matrix i_idx (result_nr, result_nc);
-  Matrix j_idx (result_nr, result_nc);
-
-  ArrayN<T> val (dim_vector (result_nr, result_nc));
-
-  if (count > 0)
-    {
-      count = 0;
-
-      octave_idx_type nr = nda.rows ();
-
-      octave_idx_type i = 0;
-
-      // Search for elements to return.  Only search the region where
-      // there are elements to be found using the count that we want
-      // to find.
+  // Fixup idx dimensions, for Matlab compatibility.
+  // find(zeros(0,0)) -> zeros(0,0)
+  // find(zeros(1,0)) -> zeros(1,0)
+  // find(zeros(0,1)) -> zeros(0,1)
+  // find(zeros(0,X)) -> zeros(0,1)
+  // find(zeros(1,1)) -> zeros(0,0) !!!! WHY?
+  // find(zeros(0,1,0)) -> zeros(0,0)
+  // find(zeros(0,1,0,1)) -> zeros(0,0) etc
+  // FIXME: I don't believe this is right. Matlab seems to violate its own docs
+  // here, because a scalar *is* a row vector.
 
-      // For compatibility, all N-d arrays are handled as if they are
-      // 2-d, with the number of columns equal to "prod (dims (2:end))".
-
-      for (k = start_el; k < end_el; k++)
-	{
-	  OCTAVE_QUIT;
-
-	  if (nda(k) != T ())
-	    {
-	      idx(count) = k + 1;
-
-	      octave_idx_type xr = k % nr;
-	      i_idx(count) = xr + 1;
-	      j_idx(count) = (k - xr) / nr + 1;
-
-	      val(count) = nda(k);
-
-	      count++;
-	    }
-
-	  i++;
-	}
-    }
-  else if (scalar_arg || (nda.rows () == 0 && ! column_vector_arg))
-    {
-      idx.resize (0, 0);
-
-      i_idx.resize (0, 0);
-      j_idx.resize (0, 0);
-
-      val.resize (dim_vector (0, 0));
-    }
+  if ((nda.numel () == 1 && idx.is_empty ())
+      || (nda.rows () == 0 && nda.dims ().numel (1) == 0))
+    idx = idx.reshape (dim_vector (0, 0));
+  else if (nda.rows () == 1 && nda.ndims () == 2)
+    idx = idx.reshape (dim_vector (1, idx.length ()));
 
   switch (nargout)
     {
     default:
     case 3:
-      retval(2) = val;
+      retval(2) = ArrayN<T> (nda.index (idx_vector (idx)));
       // Fall through!
 
     case 2:
-      retval(1) = j_idx;
-      retval(0) = i_idx;
-      break;
+      {
+        Array<octave_idx_type> jdx (idx.dims ());
+        octave_idx_type n = idx.length (), nr = nda.rows ();
+        for (octave_idx_type i = 0; i < n; i++)
+          {
+            jdx.xelem (i) = idx.xelem (i) / nr;
+            idx.xelem (i) %= nr;
+          }
+        retval(1) = NDArray (jdx, true);
+      }
+      // Fall through!
 
     case 1:
     case 0:
-      retval(0) = idx;
+      retval(0) = NDArray (idx, true);
       break;
     }
 
   return retval;
 }
 
-template octave_value_list find_nonzero_elem_idx (const Array<double>&, int,
-						  octave_idx_type, int);
-
-template octave_value_list find_nonzero_elem_idx (const Array<Complex>&, int,
-						  octave_idx_type, int);
+#define INSTANTIATE_FIND_ARRAY(T) \
+template octave_value_list find_nonzero_elem_idx (const Array<T>&, int, \
+						  octave_idx_type, int)
 
-template octave_value_list find_nonzero_elem_idx (const Array<float>&, int,
-						  octave_idx_type, int);
-
-template octave_value_list find_nonzero_elem_idx (const Array<FloatComplex>&,
-						  int, octave_idx_type, int);
+INSTANTIATE_FIND_ARRAY(double);
+INSTANTIATE_FIND_ARRAY(float);
+INSTANTIATE_FIND_ARRAY(Complex);
+INSTANTIATE_FIND_ARRAY(FloatComplex);
+INSTANTIATE_FIND_ARRAY(bool);
+INSTANTIATE_FIND_ARRAY(octave_int8);
+INSTANTIATE_FIND_ARRAY(octave_int16);
+INSTANTIATE_FIND_ARRAY(octave_int32);
+INSTANTIATE_FIND_ARRAY(octave_int64);
+INSTANTIATE_FIND_ARRAY(octave_uint8);
+INSTANTIATE_FIND_ARRAY(octave_uint16);
+INSTANTIATE_FIND_ARRAY(octave_uint32);
+INSTANTIATE_FIND_ARRAY(octave_uint64);
 
 template <typename T>
 octave_value_list
@@ -342,6 +262,9 @@
 template octave_value_list find_nonzero_elem_idx (const Sparse<Complex>&, int,
 						  octave_idx_type, int);
 
+template octave_value_list find_nonzero_elem_idx (const Sparse<bool>&, int,
+						  octave_idx_type, int);
+
 octave_value_list
 find_nonzero_elem_idx (const PermMatrix& v, int nargout, 
 		       octave_idx_type n_to_find, int direction)
@@ -561,7 +484,51 @@
 
   octave_value arg = args(0);
 
-  if (arg.is_sparse_type ())
+  if (arg.is_bool_type ())
+    {
+      if (arg.is_sparse_type ())
+        {
+	  SparseBoolMatrix v = arg.sparse_bool_matrix_value ();
+
+	  if (! error_state)
+	    retval = find_nonzero_elem_idx (v, nargout, 
+					    n_to_find, direction);
+        }
+      else
+        {
+          boolNDArray v = arg.bool_array_value ();
+
+	  if (! error_state)
+	    retval = find_nonzero_elem_idx (v, nargout, 
+					    n_to_find, direction);
+        }
+    }
+  else if (arg.is_integer_type ())
+    {
+#define DO_INT_BRANCH(INTT) \
+      else if (arg.is_ ## INTT ## _type ()) \
+        { \
+          INTT ## NDArray v = arg.INTT ## _array_value (); \
+          \
+	  if (! error_state) \
+	    retval = find_nonzero_elem_idx (v, nargout, \
+					    n_to_find, direction);\
+        }
+
+      if (false)
+        ;
+      DO_INT_BRANCH (int8)
+      DO_INT_BRANCH (int16)
+      DO_INT_BRANCH (int32)
+      DO_INT_BRANCH (int64)
+      DO_INT_BRANCH (uint8)
+      DO_INT_BRANCH (uint16)
+      DO_INT_BRANCH (uint32)
+      DO_INT_BRANCH (uint64)
+      else
+        panic_impossible ();
+    }
+  else if (arg.is_sparse_type ())
     {
       if (arg.is_real_type ())
 	{