changeset 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function * logical-index.h (New file) : Logical-indexing function. May be called on octave_value types via call_bool_index * nz-iterators.h : Add base-class nz_iterator for iterator types. Array has template bool for whether to internally store row-col or compute on the fly Add skip_ahead method which skips forward to the next nonzero after its argument Add flat_index for computing octave_idx_type index of current position (with assertion failure in the case of overflow) Move is_zero to separate file * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h * Array.h : Move forward-declaration for array_iterator to separate header file * dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx) * array-iter-decl.h (New file): Header file for forward declaration of array-iterator * direction.h : Add constants fdirc and bdirc to avoid having to reconstruct them * dv-utils.h, dv-utils.cc (New files) : Utility functions for querying and constructing dim-vectors * idx-bounds.h (New file) : Utility constants and functions for determining whether things will overflow the maximum allowed bounds * interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear index of octave_idx_type * is-zero.h (New file) : Function for determining whether an element is zero * logical-index.tst : Add tests for correct return-value dimensions and large sparse matrix behavior
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents 8d47ce2053f2
children
files libinterp/octave-value/ov-base-diag.cc libinterp/octave-value/ov-base-mat.cc libinterp/octave-value/ov-base-sparse.cc libinterp/octave-value/ov-perm.cc liboctave/array/Array.h liboctave/array/dim-vector.cc liboctave/util/array-iter-decl.h liboctave/util/direction.h liboctave/util/dv-utils.cc liboctave/util/dv-utils.h liboctave/util/idx-bounds.h liboctave/util/interp-idx.h liboctave/util/is-zero.h liboctave/util/logical-index.h liboctave/util/module.mk liboctave/util/nz-iterators.h test/logical-index.tst
diffstat 17 files changed, 832 insertions(+), 119 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/octave-value/ov-base-diag.cc	Mon Jul 14 13:07:59 2014 -0600
+++ b/libinterp/octave-value/ov-base-diag.cc	Fri Jul 25 13:39:31 2014 -0600
@@ -37,6 +37,7 @@
 #include "gripes.h"
 #include "oct-stream.h"
 #include "ops.h"
+#include "logical-index.h"
 
 #include "ls-oct-ascii.h"
 
@@ -125,6 +126,8 @@
             retval = to_dense ().do_index_op (idx, resize_ok);
         }
     }
+  else if (idx.length () == 1 && idx(0).is_bool_type ())
+    retval = call_bool_index (matrix, idx(0));
   else
     retval = to_dense ().do_index_op (idx, resize_ok);
 
--- a/libinterp/octave-value/ov-base-mat.cc	Mon Jul 14 13:07:59 2014 -0600
+++ b/libinterp/octave-value/ov-base-mat.cc	Fri Jul 25 13:39:31 2014 -0600
@@ -34,6 +34,7 @@
 #include "ov-base-mat.h"
 #include "ov-base-scalar.h"
 #include "pr-output.h"
+#include "logical-index.h"
 
 template <class MT>
 octave_value
@@ -146,15 +147,20 @@
 
     case 1:
       {
-        idx_vector i = idx (0).index_vector ();
-
-        if (! error_state)
+        if (idx(0).is_bool_type ())
+          retval = call_bool_index (matrix, idx(0));
+        else
           {
-            // optimize single scalar index.
-            if (! resize_ok && i.is_scalar ())
-              retval = cmatrix.checkelem (i(0));
-            else
-              retval = MT (matrix.index (i, resize_ok));
+            idx_vector i = idx (0).index_vector ();
+
+            if (! error_state)
+              {
+                // optimize single scalar index.
+                if (! resize_ok && i.is_scalar ())
+                  retval = cmatrix.checkelem (i(0));
+                else
+                  retval = MT (matrix.index (i, resize_ok));
+              }
           }
       }
       break;
--- a/libinterp/octave-value/ov-base-sparse.cc	Mon Jul 14 13:07:59 2014 -0600
+++ b/libinterp/octave-value/ov-base-sparse.cc	Fri Jul 25 13:39:31 2014 -0600
@@ -29,6 +29,7 @@
 #include <iomanip>
 #include <iostream>
 
+#include "logical-index.h"
 #include "oct-obj.h"
 #include "ov-base.h"
 #include "quit.h"
@@ -61,10 +62,16 @@
 
     case 1:
       {
-        idx_vector i = idx (0).index_vector ();
+        const octave_value& v = idx(0);
+        if (v.is_bool_type ())
+          retval = call_bool_index (matrix, v);
+        else
+          {
+            idx_vector i = v.index_vector ();
 
-        if (! error_state)
-          retval = octave_value (matrix.index (i, resize_ok));
+            if (!error_state)
+              retval = octave_value (matrix.index (i, resize_ok));
+          }
       }
       break;
 
--- a/libinterp/octave-value/ov-perm.cc	Mon Jul 14 13:07:59 2014 -0600
+++ b/libinterp/octave-value/ov-perm.cc	Fri Jul 25 13:39:31 2014 -0600
@@ -35,6 +35,7 @@
 #include "gripes.h"
 #include "ops.h"
 #include "pr-output.h"
+#include "logical-index.h"
 
 #include "ls-oct-ascii.h"
 
@@ -117,6 +118,8 @@
         {
           retval = matrix.checkelem (idx0(0), idx1(0));
         }
+      else if (nidx == 1 && idx(0).is_bool_type ())
+        retval = call_bool_index (matrix, idx(0));
       else
         retval = to_dense ().do_index_op (idx, resize_ok);
     }
--- a/liboctave/array/Array.h	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/array/Array.h	Fri Jul 25 13:39:31 2014 -0600
@@ -41,11 +41,7 @@
 #include "quit.h"
 #include "oct-mem.h"
 #include "oct-refcount.h"
-
-//Forward declaration for array_iterator,
-//the nonzero-iterator type for Array (in nz_iterator.h)
-template<typename T>
-class array_iterator;
+#include "array-iter-decl.h"
 
 // One dimensional array class.  Handles the reference counting for
 // all the derived classes.
--- a/liboctave/array/dim-vector.cc	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/array/dim-vector.cc	Fri Jul 25 13:39:31 2014 -0600
@@ -27,6 +27,7 @@
 
 #include <iostream>
 
+#include "idx-bounds.h"
 #include "dim-vector.h"
 
 // The maximum allowed value for a dimension extent. This will normally be a
@@ -36,7 +37,7 @@
 octave_idx_type
 dim_vector::dim_max (void)
 {
-  return std::numeric_limits<octave_idx_type>::max () - 1;
+  return max_idx;
 }
 
 void
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/liboctave/util/array-iter-decl.h	Fri Jul 25 13:39:31 2014 -0600
@@ -0,0 +1,33 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+
+#if !defined (octave_array_iter_decl_h)
+#define octave_array_iter_decl_h 1
+
+//Forward declaration of array_iterator, the nonzero-iterator type for
+//Array (defined in nz-iterators.h).
+//Must be declared in its own file because one of the template parameters
+//(userc) has a default value.
+template<typename T, bool userc = true>
+class array_iterator;
+
+#endif
--- a/liboctave/util/direction.h	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/util/direction.h	Fri Jul 25 13:39:31 2014 -0600
@@ -73,4 +73,7 @@
   }
 };
 
+extern const dir_handler<FORWARD> fdirc;
+extern const dir_handler<BACKWARD> bdirc;
+
 #endif
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/liboctave/util/dv-utils.cc	Fri Jul 25 13:39:31 2014 -0600
@@ -0,0 +1,85 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+
+#ifdef HAVE_CONFIG_H
+#include <config.h>
+#endif
+
+#include "dim-vector.h"
+#include <cassert>
+
+bool
+dv_is_extended_vector (const dim_vector& dv)
+{
+  bool found = false;  // set once the first extent != 1 has been seen
+  for (int i = 0; i < dv.length (); i++)
+    {
+      if (dv(i) != 1)  // any extent other than 1 counts, including 0
+        {
+          if (found)
+            return false;  // a second non-1 extent: not a vector
+          else
+            found = true;
+        }
+    }
+  return found;  // false when every extent is 1 (the dv_is_scalar case)
+}
+
+bool
+dv_is_scalar (const dim_vector& dv)
+{
+  for (int i = 0; i < dv.length (); i++)
+    {
+      if (dv(i) != 1)  // one extent other than 1 rules out a scalar
+        return false;
+    }
+  return true;  // every extent is 1
+}
+
+bool
+dv_is_row (const dim_vector& dv)
+{  // Row vector: exactly one extent differs from 1, and it is dv(1).
+  return dv_is_extended_vector (dv) && dv(1) != 1;
+}
+
+//Returns the index of the first dimension of dv whose extent is not 1
+int
+dv_vector_dimension (const dim_vector& dv)
+{
+  assert(dv_is_extended_vector (dv));  // precondition: exactly one extent != 1
+  for (int i = 0; i < dv.length (); i++)
+    {
+      if (dv(i) != 1)
+        return i;
+    }
+  liboctave_fatal ("No non-1 dimension");  // reached only if every extent is 1
+}
+
+dim_vector
+dv_match_vector (const dim_vector& dv, octave_idx_type numel)
+{
+  if(numel == 1)  // a single element always yields 1x1, whatever dv is
+    return dim_vector (1, 1);
+  dim_vector res = dv;  // start from dv, then overwrite its one non-1 extent
+  res (dv_vector_dimension (dv)) = numel;  // asserts that dv is a vector
+  return res;
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/liboctave/util/dv-utils.h	Fri Jul 25 13:39:31 2014 -0600
@@ -0,0 +1,40 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+
+#if !defined (octave_dv_utils_h)
+#define octave_dv_utils_h 1
+
+class dim_vector;
+
+//Returns true if the argument has exactly one nonsingleton dimension
+bool dv_is_extended_vector(const dim_vector& dv);
+
+//Returns true if all of dv's dimension-lengths are 1
+bool dv_is_scalar(const dim_vector& dv);
+
+//Returns true if the argument's second dimension is the only nonsingleton
+bool dv_is_row(const dim_vector& dv);
+
+//Returns dv with its nonsingleton extent replaced by numel (1x1 if numel is 1)
+dim_vector dv_match_vector(const dim_vector& dv, octave_idx_type numel);
+
+#endif
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/liboctave/util/idx-bounds.h	Fri Jul 25 13:39:31 2014 -0600
@@ -0,0 +1,53 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+
+#if !defined (octave_idx_bounds_h)
+#define octave_idx_bounds_h 1
+
+#include <limits>
+
+//The maximum value for octave_idx_type (and the maximum possible size of an
+//Array)
+const octave_idx_type idx_type_max_value =
+    std::numeric_limits<octave_idx_type>::max ();
+
+//The maximum index an array can have (idx_type_max_value - 1)
+const octave_idx_type max_idx = idx_type_max_value - 1;
+
+//Determines if the linear index col * height + row would exceed max_idx,
+//without computing the overflowing product.  Requires height > 0, row >= 0.
+inline bool
+row_col_to_idx_overflows (octave_idx_type row, octave_idx_type col,
+                          octave_idx_type height)
+{
+  return col > (max_idx - row) / height;  // col, not row, is scaled by height
+}
+
+//Determines if height * width would exceed idx_type_max_value.  A zero
+//height never overflows and is tested first so it never reaches the division.
+inline bool
+height_width_to_nelem_overflows (octave_idx_type height, octave_idx_type width)
+{
+  return height != 0 && width > idx_type_max_value / height;
+}
+
+#endif
--- a/liboctave/util/interp-idx.h	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/util/interp-idx.h	Fri Jul 25 13:39:31 2014 -0600
@@ -22,6 +22,9 @@
 #if !defined (octave_interp_idx_h)
 #define octave_interp_idx_h 1
 
+#include "idx-bounds.h"
+#include "Array-util.h"
+
 // Simple method for converting between C++ octave_idx_type and
 // Octave data-types (convert to double and add one).
 inline double
@@ -44,4 +47,13 @@
   return col * static_cast<double> (dims(0)) + row + 1;
 }
 
+//Converts a row-column pair to a "flat" linear index within dims.
+//Asserts (when assertions are enabled) that the result fits octave_idx_type
+inline octave_idx_type
+to_flat_idx (octave_idx_type row, octave_idx_type col, const dim_vector& dims)
+{
+  assert(!row_col_to_idx_overflows (row, col, dims(0)));  // dims(0) is the height
+  return compute_index (row, col, dims, false);  // NOTE(review): meaning of 'false' not visible here
+}
+
 #endif
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/liboctave/util/is-zero.h	Fri Jul 25 13:39:31 2014 -0600
@@ -0,0 +1,34 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+#if !defined (octave_is_zero_h)
+#define octave_is_zero_h
+
+// Generic zero test for a matrix element of type T: returns true iff t
+// compares equal (via operator==) to the literal 0 converted to T.
+template<typename T>
+inline bool
+is_zero (T t)
+{
+  return t == static_cast<T> (0);
+}
+
+#endif
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/liboctave/util/logical-index.h	Fri Jul 25 13:39:31 2014 -0600
@@ -0,0 +1,197 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+#if !defined (octave_logical_index_h)
+#define octave_logical_index_h 1
+
+#include "Array.h"
+#include "dim-vector.h"
+#include "dv-utils.h"
+#include "direction.h"
+#include "dispatch.h"
+#include "nz-iterators.h"
+#include "ov.h"
+
+//Re-addresses the nonzeros of idx as if idx were laid out with dims_rows rows
+template<typename IM>  // IM must provide iter_type, nnz () and rows ()
+Sparse<bool>
+partial_sparse_reshape (const IM& idx, octave_idx_type dims_rows)
+{
+  typename IM::iter_type idx_iter (idx);
+
+  dim_vector res_dims = dim_vector (idx.nnz (), 1);  // one slot per nonzero
+
+  Array<octave_idx_type> rows (res_dims);
+  Array<octave_idx_type> cols (res_dims);
+
+  const octave_idx_type idx_rows = idx.rows ();
+
+  octave_idx_type i;
+  octave_idx_type col_start_col = 0;  // new column where idx column idx_col begins
+  octave_idx_type col_start_row = 0;  // new row where idx column idx_col begins
+  octave_idx_type idx_col = 0;  // idx column the two offsets above refer to
+  for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc);
+      ++i, idx_iter.step (fdirc))
+    {
+      for (; idx_col < idx_iter.col (); ++idx_col)  // catch up to this nonzero's column
+        {
+          octave_idx_type next_row = col_start_row + idx_rows;  // skip one idx column
+          col_start_col += next_row / dims_rows;  // NOTE(review): dims_rows == 0 divides by zero
+          col_start_row = next_row % dims_rows;
+        }
+      octave_idx_type next_row = col_start_row + idx_iter.row ();
+      cols.xelem (i) = col_start_col + next_row / dims_rows;
+      rows.xelem (i) = next_row % dims_rows;
+    }
+  Array<bool> trueScalar (dim_vector (1, 1), true);  // presumably broadcast to every pair - confirm
+  return Sparse<bool> (trueScalar, idx_vector (rows), idx_vector (cols));  // NOTE(review): no explicit dims passed
+}
+
+// Given matrix mat and logical matrix idx, computes mat(idx) and returns the
+// selected elements as an Array.  "full" selects linear-index addressing of mat
+// (otherwise row/column addressing is used); IterM is the nonzero-iterator type
+// for M (the type of mat).  If !full, mat and idx must have the same height.
+template<bool full, typename IterM, typename M, typename IM>
+Array<typename M::element_type>
+take_bool_with_index (const M& mat, const IM& idx)
+{
+  typedef typename M::element_type ELT_T;
+
+  // After testing over 150 edge-cases in Matlab, this seems to be the rule for
+  // logical indexing return-value dimensions
+  const dim_vector idx_dims = idx.dims ();
+  const dim_vector mat_dims = mat.dims ();
+  octave_idx_type idx_nnz = idx.nnz ();  // number of selected elements
+  dim_vector res_dims;
+  if (idx_dims(0) == 0 && idx_dims(1) == 0)  // 0x0 index
+    res_dims = dim_vector (0, 0);
+  else if (dv_is_scalar (idx_dims) && idx_nnz == 0)  // scalar index that is false
+    res_dims = dim_vector (0, 0);
+  else if (dv_is_extended_vector (mat_dims))  // vector mat: keep mat's orientation
+    res_dims = dv_match_vector (mat_dims, idx_nnz);
+  else
+    {
+      if (dv_is_row (idx_dims))  // otherwise the orientation follows idx
+        res_dims = dim_vector (1, idx_nnz);
+      else
+        res_dims = dim_vector (idx_nnz, 1);
+    }
+
+  IterM mat_iter (mat);
+  typename IM::iter_type idx_iter (idx);
+
+  Array<ELT_T> res (res_dims);
+
+  octave_idx_type i;  // output slot; advances once per nonzero of idx
+  for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc);
+      ++i, idx_iter.step (fdirc))
+    {
+      bool nz;  // result of skip_ahead: whether mat is nonzero at idx's position
+      if (full)
+        nz = mat_iter.skip_ahead (idx_iter.flat_idx ());
+      else
+        nz = mat_iter.skip_ahead (idx_iter.row (), idx_iter.col ());
+      if (nz)
+        res.xelem (i) = mat_iter.data ();
+      else
+        res.xelem (i) = static_cast<ELT_T> (0);  // no nonzero there: store zero
+    }
+  return res;
+}
+
+// Forwards to take_bool_with_index, first passing idx through
+// partial_sparse_reshape when !full and idx and mat differ in height
+template<bool full, typename IterM, typename M, typename IM>
+Array<typename M::element_type>
+take_bool_index (const M& mat, const IM& idx)
+{
+  if (full || idx.rows () == mat.rows ())  // linear addressing or heights agree
+    return take_bool_with_index<full, IterM> (mat, idx);
+  else
+    return take_bool_with_index<full, IterM> (
+        mat, partial_sparse_reshape (idx, mat.rows ()));
+}
+
+template<typename IM>  // logical indexing into a PermMatrix
+Array<bool>
+bool_index (const PermMatrix& mat, const IM& idx)
+{
+  return take_bool_index<false, perm_iterator> (mat, idx);  // full = false: row/col addressing
+}
+
+template<typename ELT_T, typename IM>  // logical indexing into a DiagArray2
+Array<ELT_T>
+bool_index (const DiagArray2<ELT_T>& mat, const IM& idx)
+{
+  return take_bool_index<false, diag_iterator<ELT_T> > (mat, idx);  // full = false: row/col addressing
+}
+
+template<typename ELT_T, typename IM>  // logical indexing into a full Array
+Array<ELT_T>
+bool_index (const Array<ELT_T>& mat, const IM& idx)
+{  // full = true: linear addressing; userc = false: row/col computed on the fly
+  return take_bool_index<true, array_iterator<ELT_T, false> > (mat, idx);
+}
+
+template<typename ELT_T, typename IM>  // logical indexing into a Sparse matrix
+Sparse<ELT_T>
+bool_index (const Sparse<ELT_T>& mat, const IM& idx)
+{
+  const Array<ELT_T> res = take_bool_index<false, sparse_iterator<ELT_T> > (
+      mat, idx);  // selected elements are first gathered into an Array
+  return Sparse<ELT_T> (res);  // then converted back to Sparse
+}
+
+
+template<typename M>  // M: type of the matrix being indexed
+struct mwrapper
+{
+  template<typename IM>  // IM: concrete type of the index argument
+  struct idx_caller
+  {
+    octave_value_list
+    operator() (const IM& arg, const M& into)
+    {
+      octave_value_list res (1);  // single-element list holding the result
+      res(0) = bool_index (into, arg);  // note the order: matrix first, index second
+      return res;
+    }
+  };
+
+  // Dispatch call to bool_index for type of idx.  The nested functor allows
+  // for multiple template parameters (since dispatch assumes its template
+  // argument has exactly one template parameter)
+  static octave_value
+  do_call (const M& mat, const octave_value& idx)
+  {
+    octave_value_list res = dispatch<idx_caller> (idx, mat, "bool_index");
+    return res(0);  // unwrap the single result placed there by idx_caller
+  }
+};
+
+template<typename M>  // entry point used by the do_index_op implementations
+octave_value
+call_bool_index (const M& mat, const octave_value& idx)
+{
+  return mwrapper<M>::do_call (mat, idx);  // dispatches on the type of idx
+}
+
+#endif
--- a/liboctave/util/module.mk	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/util/module.mk	Fri Jul 25 13:39:31 2014 -0600
@@ -3,6 +3,7 @@
 
 UTIL_INC = \
   util/action-container.h \
+  util/array-iter-decl.h \
   util/base-list.h \
   util/byte-swap.h \
   util/caseless-str.h \
@@ -10,12 +11,16 @@
   util/cmd-hist.h \
   util/data-conv.h \
   util/direction.h \
+  util/dv-utils.h \
   util/find.h \
   util/functor.h \
   util/glob-match.h \
+  util/idx-bounds.h \
   util/interp-idx.h \
+  util/is-zero.h \
   util/lo-array-gripes.h \
   util/lo-cutils.h \
+  util/logical-index.h \
   util/lo-ieee.h \
   util/lo-macros.h \
   util/lo-math.h \
@@ -60,6 +65,7 @@
   util/cmd-edit.cc \
   util/cmd-hist.cc \
   util/data-conv.cc \
+  util/dv-utils.cc \
   util/glob-match.cc \
   util/lo-array-gripes.cc \
   util/lo-ieee.cc \
--- a/liboctave/util/nz-iterators.h	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/util/nz-iterators.h	Fri Jul 25 13:39:31 2014 -0600
@@ -22,6 +22,9 @@
 #if !defined (octave_nz_iterators_h)
 #define octave_nz_iterators_h 1
 
+#include <cassert>
+
+#include "array-iter-decl.h"
 #include "interp-idx.h"
 #include "oct-inttypes.h"
 #include "Array.h"
@@ -29,6 +32,7 @@
 #include "PermMatrix.h"
 #include "Sparse.h"
 #include "direction.h"
+#include "is-zero.h"
 
 // This file contains generic column-major iterators over
 // the nonzero elements of any array or matrix.  If you have a matrix mat
@@ -61,67 +65,104 @@
 //   T elem = iter.data();
 //   // ... Do something with these
 // }
-//
-// Note that array_iter for indexing over full matrices also includes
-// a iter.flat_index () method which returns an octave_idx_type.
-//
-// The other iterators to not have a flat_index() method because they
-// risk overflowing octave_idx_type.  It is recommended you take care
-// to implement your function in a way that accounts for this problem.
-//
-// FIXME: I'd like to add in these
-// default no-parameter versions of
-// begin() and step() to each of the
-// classes.  But the C++ compiler complains
-// because apparently I'm not allowed to overload
-// templated methods with non-templated ones.  Any
-// ideas for work-arounds?
-//
-//#define INCLUDE_DEFAULT_STEPS \
-//  void begin (void) \
-//    { \
-//      dir_handler<FORWARD> dirc; \
-//      begin (dirc); \
-//    } \
-//  void step (void) \
-//    { \
-//      dir_handler<FORWARD> dirc; \
-//      step (dirc); \
-//    } \
-//  bool finished (void) const \
-//    { \
-//      dir_handler<FORWARD> dirc; \
-//      return finished (dirc); \
-//    }
+
+struct rowcol  // a (row, col) position
+{
+  rowcol (void) { }  // leaves row and col uninitialized
+
+  rowcol (octave_idx_type i, const dim_vector& dims)  // from a linear index
+  {
+#if defined(BOUNDS_CHECKING)
+    check_index(i, dims);  // NOTE(review): check_index is not defined in the visible code
+#endif
+    row = i % dims(0);  // dims(0) is the height
+    col = i / dims(0);
+  }
+
+  rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims)
+  {
+#if defined(BOUNDS_CHECKING)
+    check_index(rowj, coli, dims);
+#endif
+    row = rowj;  // dims is used only by the bounds check above
+    col = coli;
+  }
+
+  octave_idx_type row;
+  octave_idx_type col;
+};
+
+// Default interp_idx for iterators that expose row (), col () and dims
+template <typename Iter>
+inline double
+default_interp_idx(const Iter& it)
+{
+  return to_interp_idx (it.row (), it.col (), it.dims);
+}
 
-// A generic method for checking if some element of a matrix with
-// element type T is zero.
-template<typename T>
-bool
-is_zero (T t)
+// Default flat_idx for iterators that expose row (), col () and dims
+template <typename Iter>
+inline octave_idx_type
+default_flat_idx (const Iter& it)
+{
+  return to_flat_idx (it.row (), it.col (), it.dims);  // asserts the index does not overflow
+}
+
+// Default skip_ahead(octave_idx_type): splits linear index i into a row and
+// column using the height it.dims(0) and forwards to skip_ahead (row, col)
+template <typename Iter>
+inline bool
+default_skip_ahead (Iter& it, octave_idx_type i)
+{
+  return it.skip_ahead (i % it.dims(0), i / it.dims(0));
+}
+
+
+template <typename M>
+class nz_iterator
 {
-  return t == static_cast<T> (0);
-}
+protected:
+  const M mat;
+
+public:
+  const dim_vector dims;
+  nz_iterator (const M& arg_mat) : mat (arg_mat), dims (arg_mat.dims ()) { }
+
+  // Because of the abundance of templates, this interface cannot be specified
+  // using pure virtual methods.  But here's what each method does:
+  //
+  // begin(dirc): moves the iterator to the beginning or the end
+  // finished(dirc): checks whether the iterator has advanced to the end or the
+  // beginning
+  // row, col, and data: return the current row, column, and element
+  // interp_idx and flat_idx: return the current linear index (interp_idx as
+  // a double the way the interpreter sees it and flat_idx as an
+  // octave_idx_type)
+  // step(dirc): steps forward to the next nonzero element
+  // skip_ahead(args): skips forward to the nearest nonzero element after
+  // (including) args.  Returns true iff the position at args is nonzero
+};
 
 // An iterator over full arrays.  When the number of dimensions exceeds
 // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1)
 //
 // This mimics the behavior of the "find" method (both in Octave and Matlab)
 // on many-dimensional matrices.
+//
+// userc indicates whether to (true) track row and column values or (false)
+// compute them on the fly as needed
 
-template<typename T>
-class array_iterator
+template<typename T, bool userc>
+class array_iterator : public nz_iterator<Array<T> >
 {
 private:
-  const Array<T>& mat;
+  typedef nz_iterator<Array<T> > Base;
 
-  //Actual total number of columns = mat.dims().numel(1)
-  //can be different from length of row dimension
   const octave_idx_type totcols;
   const octave_idx_type numels;
 
-  octave_idx_type coli;
-  octave_idx_type rowj;
+  rowcol myrc;
+
   octave_idx_type my_idx;
 
   template<direction dir>
@@ -129,11 +170,14 @@
   step_once (dir_handler<dir> dirc)
   {
     my_idx += dir;
-    rowj += dir;
-    if (dirc.is_ended (rowj, mat.rows ()))
+    if (userc)
       {
-        rowj = dirc.begin (mat.rows ());
-        coli += dir;
+        myrc.row += dir;
+        if (dirc.is_ended (myrc.row, Base::mat.rows ()))
+          {
+            myrc.row = dirc.begin (Base::mat.rows ());
+            myrc.col += dir;
+          }
       }
   }
 
@@ -147,35 +191,66 @@
       }
   }
 
+  template <direction dir>
+  bool
+  next_nz (dir_handler<dir> dirc)
+  {
+    if (is_zero (data ()))
+      {
+        step (dirc);
+        return false;
+      }
+    else
+      return true;
+  }
+
+  bool
+  skip_ahead (octave_idx_type i, const rowcol& rc)
+  {
+    if (i < my_idx)
+      return false;
+    my_idx = i;
+    if (userc)
+      myrc = rc;
+    return next_nz (fdirc);
+  }
 
 public:
   array_iterator (const Array<T>& arg_mat)
-      : mat (arg_mat), totcols (arg_mat.dims ().numel (1)), numels (
+      : nz_iterator<Array<T> > (arg_mat), totcols (Base::dims.numel (1)), numels (
           totcols * arg_mat.rows ())
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    coli = dirc.begin (totcols);
-    rowj = dirc.begin (mat.rows ());
-    my_idx = dirc.begin (mat.numel ());
+    if(userc)
+      {
+        myrc.col = dirc.begin (totcols);
+        myrc.row = dirc.begin (Base::mat.rows ());
+      }
+    my_idx = dirc.begin (Base::mat.numel ());
     move_to_nz (dirc);
   }
 
   octave_idx_type
   col (void) const
   {
-    return coli;
+    if (userc)
+      return myrc.col;
+    else
+      return my_idx / Base::mat.rows ();
   }
   octave_idx_type
   row (void) const
   {
-    return rowj;
+    if (userc)
+      return myrc.row;
+    else
+      return my_idx % Base::mat.rows ();
   }
   double
   interp_idx (void) const
@@ -190,7 +265,7 @@
   T
   data (void) const
   {
-    return mat.elem (my_idx);
+    return Base::mat.elem (my_idx);
   }
 
   template<direction dir>
@@ -206,13 +281,41 @@
   {
     return dirc.is_ended (my_idx, numels);
   }
+
+  bool
+  skip_ahead (octave_idx_type i)
+  {
+    return skip_ahead (i, rowcol (i, Base::dims));
+  }
+
+  bool
+  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
+  {
+    return skip_ahead (coli * Base::mat.rows () + rowj,
+                       rowcol (rowj, coli, Base::dims));
+  }
+
+  bool
+  skip_ahead (const Array<octave_idx_type>& idxs)
+  {  // idxs holds one subscript per dimension; folds them into (row, col)
+    //TODO Check bounds (idxs must hold at least two subscripts)
+    octave_idx_type rowj = idxs(0);
+    octave_idx_type coli = 0;  // combined column over dimensions 1 and up
+    for(int i = idxs.numel () - 1; i > 1; --i) {
+        coli += idxs(i);
+        coli *= Base::dims(i - 1);  // stride of subscript i is the extent of i - 1
+    }
+    coli += idxs(1);
+    return skip_ahead (rowj, coli);
+  }
 };
 
 template<typename T>
-class sparse_iterator
+class sparse_iterator : public nz_iterator<Sparse<T> >
 {
 private:
-  const Sparse<T>& mat;
+  typedef nz_iterator<Sparse<T> > Base;
+
   octave_idx_type coli;
   octave_idx_type my_idx;
 
@@ -221,32 +324,28 @@
   adjust_col (dir_handler<dir> dirc)
   {
     while (!finished (dirc)
-        && dirc.is_ended (my_idx, mat.cidx (coli), mat.cidx (coli + 1)))
+        && dirc.is_ended (my_idx, Base::mat.cidx (coli), Base::mat.cidx (coli + 1)))
       coli += dir;
   }
 
+  void jump_to_row (octave_idx_type rowj);
+
 public:
   sparse_iterator (const Sparse<T>& arg_mat) :
-      mat (arg_mat)
+      nz_iterator<Sparse<T> > (arg_mat)
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    coli = dirc.begin (mat.cols ());
-    my_idx = dirc.begin (mat.nnz ());
+    coli = dirc.begin (Base::mat.cols ());
+    my_idx = dirc.begin (Base::mat.nnz ());
     adjust_col (dirc);
   }
 
-  double
-  interp_idx (void) const
-  {
-    return to_interp_idx (row (), col (), mat.dims ());
-  }
   octave_idx_type
   col (void) const
   {
@@ -255,12 +354,13 @@
   octave_idx_type
   row (void) const
   {
-    return mat.ridx (my_idx);
+    return Base::mat.ridx (my_idx);
   }
+
   T
   data (void) const
   {
-    return mat.data (my_idx);
+    return Base::mat.data (my_idx);
   }
   template<direction dir>
   void
@@ -273,15 +373,49 @@
   bool
   finished (dir_handler<dir> dirc) const
   {
-    return dirc.is_ended (coli, mat.cols ());
+    return dirc.is_ended (coli, Base::mat.cols ());
+  }
+
+  bool
+  skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli)
+  {
+    //TODO Check bounds
+    if (arg_coli < coli || (arg_coli == coli && rowj < this->row ()))
+      {
+        return false;
+      }
+    else if (arg_coli > coli)
+      {
+        coli = arg_coli;
+        my_idx = Base::mat.cidx (arg_coli);
+      }
+    jump_to_row (rowj);
+    return coli == arg_coli && this->row () == rowj;
+  }
+
+  double
+  interp_idx (void) const
+  {
+    return default_interp_idx(*this);
+  }
+  octave_idx_type
+  flat_idx (void) const
+  {
+    return default_flat_idx(*this);
+  }
+  bool
+  skip_ahead (octave_idx_type i)
+  {
+    return default_skip_ahead (*this, i);
   }
 };
 
 template<typename T>
-class diag_iterator
+class diag_iterator : public nz_iterator <DiagArray2<T> >
 {
 private:
-  const DiagArray2<T>& mat;
+  typedef nz_iterator <DiagArray2<T> > Base;
+
   octave_idx_type my_idx;
 
   template <direction dir>
@@ -296,25 +430,19 @@
 
 public:
   diag_iterator (const DiagArray2<T>& arg_mat) :
-      mat (arg_mat)
+      nz_iterator<DiagArray2<T> > (arg_mat)
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    my_idx = dirc.begin (mat.diag_length ());
+    my_idx = dirc.begin (Base::mat.diag_length ());
     move_to_nz (dirc);
   }
 
-  double
-  interp_idx (void) const
-  {
-    return to_interp_idx (row (), col (), mat.dims ());
-  }
   octave_idx_type
   col (void) const
   {
@@ -325,10 +453,11 @@
   {
     return my_idx;
   }
+
   T
   data (void) const
   {
-    return mat.dgelem (my_idx);
+    return Base::mat.dgelem (my_idx);
   }
   template<direction dir>
   void
@@ -341,37 +470,58 @@
   bool
   finished (dir_handler<dir> dirc) const
   {
-    return dirc.is_ended (my_idx, mat.diag_length ());
+    return dirc.is_ended (my_idx, Base::mat.diag_length ());
+  }
+
+  bool
+  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
+  {
+    if (coli < my_idx)  // target column is behind the current diagonal position
+      return false;
+    my_idx = coli + (rowj > coli);  // below the diagonal: start from the next entry
+    move_to_nz (fdirc);  // NOTE(review): presumably skips zero diagonal entries -- confirm
+    return rowj == coli && coli == my_idx;  // true iff stopped on the target itself
+  }
+
+  double
+  interp_idx (void) const
+  {
+    return default_interp_idx(*this);
+  }
+  octave_idx_type
+  flat_idx (void) const
+  {
+    return default_flat_idx(*this);
+  }
+  bool
+  skip_ahead (octave_idx_type i)
+  {
+    return default_skip_ahead (*this, i);
   }
 };
 
-class perm_iterator
+class perm_iterator : public nz_iterator<PermMatrix>
 {
 private:
-  const PermMatrix& mat;
+  typedef nz_iterator<PermMatrix> Base;
+
   octave_idx_type my_idx;
 
 public:
   perm_iterator (const PermMatrix& arg_mat) :
-      mat (arg_mat)
+      nz_iterator<PermMatrix> (arg_mat)
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    my_idx = dirc.begin (mat.cols ());
+    my_idx = dirc.begin (Base::mat.cols ());
   }
 
   octave_idx_type
-  interp_idx (void) const
-  {
-    return to_interp_idx (row (), col (), mat.dims ());
-  }
-  octave_idx_type
   col (void) const
   {
     return my_idx;
@@ -379,8 +529,9 @@
   octave_idx_type
   row (void) const
   {
-    return mat.perm_elem (my_idx);
+    return Base::mat.perm_elem (my_idx);
   }
+
   bool
   data (void) const
   {
@@ -396,8 +547,68 @@
   bool
   finished (dir_handler<dir> dirc) const
   {
-    return dirc.is_ended (my_idx, mat.rows ());
+    return dirc.is_ended (my_idx, Base::mat.rows ());
+  }
+
+  // Seek to (rowj, coli) or the next nonzero; true iff (rowj, coli) is nonzero.
+  bool
+  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
+  {
+    if (coli < my_idx || (coli == my_idx && rowj < this->row ()))
+      return false;  // target precedes the current position.  TODO Check bounds
+    my_idx = coli + (rowj > Base::mat.perm_elem (coli));
+    return my_idx == coli && rowj == this->row ();
+  }
+
+  double
+  interp_idx (void) const
+  {
+    return default_interp_idx(*this);
+  }
+  octave_idx_type
+  flat_idx (void) const
+  {
+    return default_flat_idx(*this);
+  }
+  bool
+  skip_ahead (octave_idx_type i)
+  {
+    return default_skip_ahead (*this, i);
   }
 };
 
+// Uses a one-sided binary search to move to the next element in this column
+// whose row is at least rowj.  The one-sided binary search guarantees
+// O(log(rowj - currentRow)) time to find it.  If the column has no such
+// element, my_idx is set to cidx (coli + 1) before adjust_col runs.
+template<typename T>
+void
+sparse_iterator<T>::jump_to_row (octave_idx_type rowj)
+{
+  octave_idx_type ub = Base::mat.cidx (coli + 1);
+  octave_idx_type lo = my_idx - 1;
+  octave_idx_type hi = my_idx;
+  octave_idx_type hidiff = 1;
+  // Gallop forward with doubling steps.  hi < ub is tested before ridx (hi)
+  // so an empty or exhausted column never reads past its own entries.
+  while (hi < ub && Base::mat.ridx (hi) < rowj)
+    {
+      lo = hi;
+      hidiff *= 2;
+      hi += hidiff;
+      if (hi > ub)
+        hi = ub;
+    }
+  while (hi - lo > 1)
+    {
+      octave_idx_type mid = (lo + hi) / 2;
+      if (Base::mat.ridx (mid) < rowj)
+        lo = mid;
+      else
+        hi = mid;
+    }
+  my_idx = hi;
+  adjust_col (fdirc);
+}
+
 #endif
--- a/test/logical-index.tst	Mon Jul 14 13:07:59 2014 -0600
+++ b/test/logical-index.tst	Fri Jul 25 13:39:31 2014 -0600
@@ -30,13 +30,18 @@
 %!assert (isempty (a(logical ([0,0,0,0]))))
 %!assert (a(logical ([1,1,1,1])), [9,8,7,6])
 %!assert (a(logical ([0,1,1,0])), [8,7])
+%!assert (a(logical ([0;1;1;0])), [8,7])
 %!assert (a(logical ([1,1])), [9,8])
 
+%! a = permute(a,[1,3,2]);
+%!assert (a(logical ([0,1,1,0])), permute([8,7],[1,3,2]))
+
 %!shared a
 %! a = [9,8;7,6];
-%!assert (isempty (a(logical ([0,0,0,0]))))
+%!assert (a(logical ([0,0,0,0])), zeros(1,0))
 %!assert (a(logical ([1,1,1,1])), [9,7,8,6])
 %!assert (a(logical ([0,1,1,0])), [7,8])
+%!assert (a(logical ([0;1;1;0])), [7;8])
 %!assert (a(logical (0:1),logical (0:1)), 6)
 %!assert (a(logical (0:1),2:-1:1), [6,7])
 %!assert (a(logical (0:1),logical ([0,1])), 6)
@@ -71,3 +76,21 @@
 %!assert (a(logical ([1,1]),1), [9;7])
 %!assert (a(logical ([1,1]),logical ([1,1])), [9,8;7,6])
 
+%!assert (a(logical (permute([1,1,1,1],[1,3,2]))), [9;7;8;6])
+%!assert (a(logical (permute([0,1,1,0],[1,3,2]))), [7;8])
+%!assert (a(logical([])), [])
+%!assert (a(false), [])
+%!assert (a(false(0,0,1)), [])
+%!assert (a(false(2,2)),zeros(0,1))
+%!assert (a(false(1,4)),zeros(1,0))
+%!assert (a(false(4,1)),zeros(0,1))
+%!assert (a(false(1,1,4)),zeros(0,1))
+
+%!shared v,a,b,c
+%! v = sparse((1:100000).');
+%! v = v .* mod(v,2);
+%! a = sparse(diag(v));
+%! b = logical(speye(100000));
+%! c = [b(:,1:2:end); b(:,2:2:end)];
+%!assert (a(b), v)
+%!assert (a(c), v)