changeset 8955:6d3fcbf89267

Add an override to Octave's find() for permutation matrices. Because of find()'s count-limiting and direction arguments, this is slightly more complicated than just copying the permutation vector. I suspect this is a common operation for people who don't know about the 'vector' option to lu().
author Jason Riedy <jason@acm.org>
date Tue, 10 Mar 2009 21:54:39 -0400
parents 97c84c4c2247
children d91fa4b20bbb
files src/ChangeLog src/DLD-FUNCTIONS/find.cc
diffstat 2 files changed, 142 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Tue Mar 10 21:54:34 2009 -0400
+++ b/src/ChangeLog	Tue Mar 10 21:54:39 2009 -0400
@@ -1,3 +1,10 @@
+2009-03-10  Jason Riedy  <jason@acm.org>
+
+	* DLD-FUNCTIONS/find.cc (find_nonzero_elem_idx): New override
+	for find on PermMatrix.
+	(find): Add a branch testing arg.is_perm_matrix () and calling the
+	above override.
+
 2009-03-10  John W. Eaton  <jwe@octave.org>
 
 	* c-file-ptr-stream.cc, dynamic-ld.cc, error.cc, lex.l, pager.cc,
--- a/src/DLD-FUNCTIONS/find.cc	Tue Mar 10 21:54:34 2009 -0400
+++ b/src/DLD-FUNCTIONS/find.cc	Tue Mar 10 21:54:39 2009 -0400
@@ -333,6 +333,117 @@
 template octave_value_list find_nonzero_elem_idx (const Sparse<Complex>&, int,
 						  octave_idx_type, int);
 
+octave_value_list
+find_nonzero_elem_idx (const PermMatrix& v, int nargout, 
+		       octave_idx_type n_to_find, int direction)
+{
+  // There are far fewer special cases to handle for a PermMatrix.
+  octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
+
+  octave_idx_type nc = v.cols();
+  octave_idx_type start_nc, end_nc, count;
+ 
+  // Determine the range to search.
+  if (n_to_find < 0 || n_to_find >= nc)
+    {
+      start_nc = 0;
+      end_nc = nc;
+      n_to_find = nc;
+      count = nc;
+    }
+  else if (direction > 0)
+    {
+      start_nc = 0;
+      end_nc = n_to_find;
+      count = n_to_find;
+    }
+  else
+    {
+      start_nc = nc - n_to_find;
+      end_nc = nc;
+      count = n_to_find;
+    }
+
+  bool scalar_arg = (v.rows () == 1 && v.cols () == 1);
+
+  Matrix idx (count, 1);
+  Matrix i_idx (count, 1);
+  Matrix j_idx (count, 1);
+  // Every value is 1.
+  ArrayN<double> val (dim_vector (count, 1), 1.0);
+
+  if (count > 0)
+    {
+      const octave_idx_type* p = v.data ();
+      if (v.is_col_perm ())
+        {
+          for (octave_idx_type k = 0; k < count; k++) 
+            {
+              OCTAVE_QUIT;
+              const octave_idx_type j = start_nc + k;
+              const octave_idx_type i = p[j];
+              i_idx(k) = static_cast<double> (1+i);
+              j_idx(k) = static_cast<double> (1+j);
+              idx(k) = j * nc + i + 1;
+            }
+        }
+      else
+        {
+          for (octave_idx_type k = 0; k < count; k++) 
+            {
+              OCTAVE_QUIT;
+              const octave_idx_type i = start_nc + k;
+              const octave_idx_type j = p[i];
+              // Scatter into the index arrays according to
+              // j adjusted by the start point.
+              const octave_idx_type koff = j - start_nc;
+              i_idx(koff) = static_cast<double> (1+i);
+              j_idx(koff) = static_cast<double> (1+j);
+              idx(koff) = j * nc + i + 1;
+            }
+        }
+    }
+  else if (scalar_arg)
+    {
+      // Same odd compatibility case as the other overrides.
+      idx.resize (0, 0);
+      i_idx.resize (0, 0);
+      j_idx.resize (0, 0);
+      val.resize (dim_vector (0, 0));
+    }
+
+  switch (nargout)
+    {
+    case 0:
+    case 1:
+      retval(0) = idx;
+      break;
+
+    case 5:
+      retval(4) = nc;
+      // Fall through
+
+    case 4:
+      retval(3) = nc;
+      // Fall through
+
+    case 3:
+      retval(2) = val;
+      // Fall through!
+
+    case 2:
+      retval(1) = j_idx;
+      retval(0) = i_idx;
+      break;
+
+    default:
+      panic_impossible ();
+      break;
+    }
+
+  return retval;
+}
+
 DEFUN_DLD (find, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {} find (@var{x})\n\
@@ -462,6 +573,12 @@
       else 
 	gripe_wrong_type_arg ("find", arg);
     }
+  else if (arg.is_perm_matrix ()) {
+    PermMatrix P = arg.perm_matrix_value ();
+
+    if (! error_state)
+      retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction);
+  }
   else
     {
       if (arg.is_single_type ())
@@ -542,6 +659,24 @@
 %! assert(j, [1; 2; 3]);
 %! assert(v, single([-1; 3; 2]));
 
+%!test
+%! pcol = [5 1 4 3 2];
+%! P = eye (5) (:, pcol);
+%! [i, j, v] = find (P);
+%! [ifull, jfull, vfull] = find (full (P));
+%! assert (i, ifull);
+%! assert (j, jfull);
+%! assert (all (v == 1));
+
+%!test
+%! prow = [5 1 4 3 2];
+%! P = eye (5) (prow, :);
+%! [i, j, v] = find (P);
+%! [ifull, jfull, vfull] = find (full (P));
+%! assert (i, ifull);
+%! assert (j, jfull);
+%! assert (all (v == 1));
+
 %!error <Invalid call to find.*> find ();
 
  */