diff liboctave/util/nz-iterators.h @ 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function * logical-index.h (New file) : Logical-indexing function. May be called on octave_value types via call_bool_index * nz-iterators.h : Add base-class nz_iterator for iterator types. Array has template bool for whether to internally store row-col or compute on the fly Add skip_ahead method which skips forward to the next nonzero after its argument Add flat_index for computing octave_idx_type index of current position (with assertion failure in the case of overflow) Move is_zero to separate file * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h * Array.h : Move forward-declaration for array_iterator to separate header file * dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx) * array-iter-decl.h (New file): Header file for forward declaration of array-iterator * direction.h : Add constants fdirc and bdirc to avoid having to reconstruct them * dv-utils.h, dv-utils.cc (New files) : Utility functions for querying and constructing dim-vectors * idx-bounds.h (New file) : Utility constants and functions for determining whether things will overflow the maximum allowed bounds * interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear index of octave_idx_type * is-zero.h (New file) : Function for determining whether an element is zero * logical-index.tst : Add tests for correct return-value dimensions and large sparse matrix behavior
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents 8d47ce2053f2
children
line wrap: on
line diff
--- a/liboctave/util/nz-iterators.h	Mon Jul 14 13:07:59 2014 -0600
+++ b/liboctave/util/nz-iterators.h	Fri Jul 25 13:39:31 2014 -0600
@@ -22,6 +22,9 @@
 #if !defined (octave_nz_iterators_h)
 #define octave_nz_iterators_h 1
 
+#include <cassert>
+
+#include "array-iter-decl.h"
 #include "interp-idx.h"
 #include "oct-inttypes.h"
 #include "Array.h"
@@ -29,6 +32,7 @@
 #include "PermMatrix.h"
 #include "Sparse.h"
 #include "direction.h"
+#include "is-zero.h"
 
 // This file contains generic column-major iterators over
 // the nonzero elements of any array or matrix.  If you have a matrix mat
@@ -61,67 +65,104 @@
 //   T elem = iter.data();
 //   // ... Do something with these
 // }
-//
-// Note that array_iter for indexing over full matrices also includes
-// a iter.flat_index () method which returns an octave_idx_type.
-//
-// The other iterators to not have a flat_index() method because they
-// risk overflowing octave_idx_type.  It is recommended you take care
-// to implement your function in a way that accounts for this problem.
-//
-// FIXME: I'd like to add in these
-// default no-parameter versions of
-// begin() and step() to each of the
-// classes.  But the C++ compiler complains
-// because apparently I'm not allowed to overload
-// templated methods with non-templated ones.  Any
-// ideas for work-arounds?
-//
-//#define INCLUDE_DEFAULT_STEPS \
-//  void begin (void) \
-//    { \
-//      dir_handler<FORWARD> dirc; \
-//      begin (dirc); \
-//    } \
-//  void step (void) \
-//    { \
-//      dir_handler<FORWARD> dirc; \
-//      step (dirc); \
-//    } \
-//  bool finished (void) const \
-//    { \
-//      dir_handler<FORWARD> dirc; \
-//      return finished (dirc); \
-//    }
+
+struct rowcol  // a (row, col) position in a column-major matrix
+{
+  rowcol (void) : row (0), col (0) { }  // origin rather than indeterminate values
+  // Split the column-major linear index i using the row count dims(0).
+  rowcol (octave_idx_type i, const dim_vector& dims)
+  {
+#if defined(BOUNDS_CHECKING)
+    check_index(i, dims);
+#endif
+    row = i % dims(0);
+    col = i / dims(0);
+  }
+  // Store an explicit pair; dims is consulted only under BOUNDS_CHECKING.
+  rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims)
+  {
+#if defined(BOUNDS_CHECKING)
+    check_index(rowj, coli, dims);
+#endif
+    row = rowj;
+    col = coli;
+  }
+
+  octave_idx_type row;
+  octave_idx_type col;
+};
+
+// Shared fallback for interp_idx: builds the index from row, col and dims
+template <typename NzIter>
+inline double
+default_interp_idx(const NzIter& nzit)
+{
+  return to_interp_idx (nzit.row (), nzit.col (), nzit.dims);
+}
 
-// A generic method for checking if some element of a matrix with
-// element type T is zero.
-template<typename T>
-bool
-is_zero (T t)
+// Shared fallback for flat_idx: builds the index from row, col and dims
+template <typename NzIter>
+inline octave_idx_type
+default_flat_idx (const NzIter& nzit)
+{
+  return to_flat_idx (nzit.row (), nzit.col (), nzit.dims);
+}
+
+// Shared fallback for skip_ahead(octave_idx_type): splits the linear index
+// by the row count dims(0) and forwards to the two-argument skip_ahead
+template <typename NzIter>
+inline bool
+default_skip_ahead (NzIter& nzit, octave_idx_type i)
+{
+  return nzit.skip_ahead (i % nzit.dims(0), i / nzit.dims(0));
+}
+
+
+template <typename M>
+class nz_iterator  // shared base: the matrix plus its dims
 {
-  return t == static_cast<T> (0);
-}
+protected:
+  const M mat;  // held by value (a copy of the constructor argument)
+
+public:
+  const dim_vector dims;  // arg_mat.dims (), read by the default_* helpers
+  nz_iterator (const M& arg_mat) : mat (arg_mat), dims (arg_mat.dims ()) { }
+
+  // Because of the abundance of templates, this interface cannot be declared
+  // with pure virtual methods.  Each derived iterator is expected to provide:
+  //
+  // begin(dirc): moves the iterator to the beginning or the end
+  // finished(dirc): checks whether the iterator has advanced to the end or the
+  // beginning
+  // row, col, and data: return the current row, column, and element
+  // interp_idx and flat_idx: return the current linear index (interp_idx as
+  // a double the way the interpreter sees it and flat_idx as an
+  // octave_idx_type)
+  // step(dirc): steps to the next nonzero element in direction dirc
+  // skip_ahead(args): moves forward to the first nonzero element at or after
+  // position args.  Returns true iff the element at args itself is nonzero
+};
 
 // An iterator over full arrays.  When the number of dimensions exceeds
 // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1)
 //
 // This mimics the behavior of the "find" method (both in Octave and Matlab)
 // on many-dimensional matrices.
+//
+// userc selects how row () and col () are obtained: true keeps a running
+// (row, col) pair up to date, false derives them from the linear index
 
-template<typename T>
-class array_iterator
+template<typename T, bool userc>
+class array_iterator : public nz_iterator<Array<T> >
 {
 private:
-  const Array<T>& mat;
+  typedef nz_iterator<Array<T> > Base;
 
-  //Actual total number of columns = mat.dims().numel(1)
-  //can be different from length of row dimension
   const octave_idx_type totcols;
   const octave_idx_type numels;
 
-  octave_idx_type coli;
-  octave_idx_type rowj;
+  rowcol myrc;
+
   octave_idx_type my_idx;
 
   template<direction dir>
@@ -129,11 +170,14 @@
   step_once (dir_handler<dir> dirc)
   {
     my_idx += dir;
-    rowj += dir;
-    if (dirc.is_ended (rowj, mat.rows ()))
+    if (userc)  // the (row, col) pair is maintained only when userc is set
       {
-        rowj = dirc.begin (mat.rows ());
-        coli += dir;
+        myrc.row += dir;
+        if (dirc.is_ended (myrc.row, Base::mat.rows ()))  // row range ended
+          {
+            myrc.row = dirc.begin (Base::mat.rows ());  // restart the row
+            myrc.col += dir;  // and move one column in direction dir
+          }
       }
   }
 
@@ -147,35 +191,66 @@
       }
   }
 
+  // Returns true when the current element is already nonzero; otherwise
+  // advances with step (dirc) and returns false.
+  template <direction dir>
+  bool
+  next_nz (dir_handler<dir> dirc)
+  {
+    const bool here_is_nz = ! is_zero (data ());
+    if (! here_is_nz)
+      step (dirc);
+
+    return here_is_nz;
+  }
+
+  // Shared worker for the public overloads: i is the linear index and rc the
+  // matching (row, col).  Targets behind the current position are ignored.
+  bool
+  skip_ahead (octave_idx_type i, const rowcol& rc)
+  {
+    if (my_idx > i) return false;
+    if (userc) myrc = rc;  // the cached pair is used only when userc is set
+    my_idx = i;
+    return next_nz (fdirc);
+  }
 
 public:
   array_iterator (const Array<T>& arg_mat)
-      : mat (arg_mat), totcols (arg_mat.dims ().numel (1)), numels (
+      : nz_iterator<Array<T> > (arg_mat), totcols (Base::dims.numel (1)), numels (
           totcols * arg_mat.rows ())
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);  // fdirc: shared forward handler replacing the local dirc
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    coli = dirc.begin (totcols);
-    rowj = dirc.begin (mat.rows ());
-    my_idx = dirc.begin (mat.numel ());
+    if(userc)  // the (row, col) pair is only initialised when it is tracked
+      {
+        myrc.col = dirc.begin (totcols);
+        myrc.row = dirc.begin (Base::mat.rows ());
+      }
+    my_idx = dirc.begin (Base::mat.numel ());  // linear index is always kept
     move_to_nz (dirc);
   }
 
   octave_idx_type
   col (void) const
   {
-    return coli;
+    if (userc)
+      return myrc.col;  // tracked value
+    else
+      return my_idx / Base::mat.rows ();  // derived from the linear index
   }
   octave_idx_type
   row (void) const
   {
-    return rowj;
+    if (userc)
+      return myrc.row;  // tracked value
+    else
+      return my_idx % Base::mat.rows ();  // derived from the linear index
   }
   double
   interp_idx (void) const
@@ -190,7 +265,7 @@
   T
   data (void) const
   {
-    return mat.elem (my_idx);
+    return Base::mat.elem (my_idx);  // element at the current linear index
   }
 
   template<direction dir>
@@ -206,13 +281,41 @@
   {
     return dirc.is_ended (my_idx, numels);
   }
+
+  bool  // linear-index overload
+  skip_ahead (octave_idx_type i)
+  {
+    return skip_ahead (i, rowcol (i, Base::dims));  // rowcol splits i by dims(0)
+  }
+
+  bool  // (row, column) overload
+  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
+  {
+    const octave_idx_type flat = rowj + Base::mat.rows () * coli;
+    return skip_ahead (flat, rowcol (rowj, coli, Base::dims));
+  }
+
+  bool  // subscript-vector overload; result is that of skip_ahead (row, col)
+  skip_ahead (const Array<octave_idx_type>& idxs)
+  {
+    // Fold subscripts 1..n-1 into one column index by Horner's rule:
+    // coli = idxs(1) + dims(1) * (idxs(2) + dims(2) * (...)).
+    // TODO: check bounds (idxs needs at least two entries, each in range).
+    octave_idx_type coli = 0;
+    for (int i = idxs.numel () - 1; i > 1; --i)
+      // The stride of subscript i is dims(i - 1), not dims(i).
+      coli = (coli + idxs(i)) * Base::dims(i - 1);
+    coli += idxs(1);
+    return skip_ahead (idxs(0), coli);
+  }
 };
 
 template<typename T>
-class sparse_iterator
+class sparse_iterator : public nz_iterator<Sparse<T> >
 {
 private:
-  const Sparse<T>& mat;
+  typedef nz_iterator<Sparse<T> > Base;
+
   octave_idx_type coli;
   octave_idx_type my_idx;
 
@@ -221,32 +324,28 @@
   adjust_col (dir_handler<dir> dirc)
   {
     while (!finished (dirc)
-        && dirc.is_ended (my_idx, mat.cidx (coli), mat.cidx (coli + 1)))
+        && dirc.is_ended (my_idx, Base::mat.cidx (coli), Base::mat.cidx (coli + 1)))
       coli += dir;
   }
 
+  void jump_to_row (octave_idx_type rowj);
+
 public:
   sparse_iterator (const Sparse<T>& arg_mat) :
-      mat (arg_mat)
+      nz_iterator<Sparse<T> > (arg_mat)
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);  // fdirc: shared forward handler replacing the local dirc
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    coli = dirc.begin (mat.cols ());
-    my_idx = dirc.begin (mat.nnz ());
+    coli = dirc.begin (Base::mat.cols ());  // column cursor
+    my_idx = dirc.begin (Base::mat.nnz ());  // cursor into the stored entries
     adjust_col (dirc);
   }
 
-  double
-  interp_idx (void) const
-  {
-    return to_interp_idx (row (), col (), mat.dims ());
-  }
   octave_idx_type
   col (void) const
   {
@@ -255,12 +354,13 @@
   octave_idx_type
   row (void) const
   {
-    return mat.ridx (my_idx);
+    return Base::mat.ridx (my_idx);  // row of the current stored entry
   }
+
   T
   data (void) const
   {
-    return mat.data (my_idx);
+    return Base::mat.data (my_idx);  // value of the current stored entry
   }
   template<direction dir>
   void
@@ -273,15 +373,49 @@
   bool
   finished (dir_handler<dir> dirc) const
   {
-    return dirc.is_ended (coli, mat.cols ());
+    return dirc.is_ended (coli, Base::mat.cols ());
+  }
+
+  bool
+  skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli)
+  {
+    // TODO: check bounds.  Targets behind the current position are ignored.
+    const bool same_col = (arg_coli == coli);
+    if (arg_coli < coli || (same_col && rowj < this->row ()))
+      return false;
+    if (! same_col)
+      {
+        // Later column: restart the search at that column's first entry.
+        coli = arg_coli;
+        my_idx = Base::mat.cidx (arg_coli);
+      }
+    jump_to_row (rowj);
+    return coli == arg_coli && this->row () == rowj;
+  }
+
+  double
+  interp_idx (void) const
+  {
+    return default_interp_idx(*this);  // generic row/col/dims conversion
+  }
+  octave_idx_type
+  flat_idx (void) const
+  {
+    return default_flat_idx(*this);  // generic row/col/dims conversion
+  }
+  bool  // linear-index overload
+  skip_ahead (octave_idx_type i)
+  {
+    return default_skip_ahead (*this, i);  // split i, then skip_ahead (row, col)
   }
 };
 
 template<typename T>
-class diag_iterator
+class diag_iterator : public nz_iterator <DiagArray2<T> >
 {
 private:
-  const DiagArray2<T>& mat;
+  typedef nz_iterator <DiagArray2<T> > Base;
+
   octave_idx_type my_idx;
 
   template <direction dir>
@@ -296,25 +430,19 @@
 
 public:
   diag_iterator (const DiagArray2<T>& arg_mat) :
-      mat (arg_mat)
+      nz_iterator<DiagArray2<T> > (arg_mat)
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);  // fdirc: shared forward handler replacing the local dirc
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    my_idx = dirc.begin (mat.diag_length ());
+    my_idx = dirc.begin (Base::mat.diag_length ());  // cursor along the diagonal
     move_to_nz (dirc);
   }
 
-  double
-  interp_idx (void) const
-  {
-    return to_interp_idx (row (), col (), mat.dims ());
-  }
   octave_idx_type
   col (void) const
   {
@@ -325,10 +453,11 @@
   {
     return my_idx;
   }
+
   T
   data (void) const
   {
-    return mat.dgelem (my_idx);
+    return Base::mat.dgelem (my_idx);  // diagonal element at the cursor
   }
   template<direction dir>
   void
@@ -341,37 +470,58 @@
   bool
   finished (dir_handler<dir> dirc) const
   {
-    return dirc.is_ended (my_idx, mat.diag_length ());
+    return dirc.is_ended (my_idx, Base::mat.diag_length ());
+  }
+
+  bool  // true iff (rowj, coli) is on the diagonal and that element is nonzero
+  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
+  {
+    if (my_idx > coli)  // target column is behind the current position
+      return false;
+    my_idx = (rowj > coli) ? coli + 1 : coli;  // below the diagonal -> next column
+    move_to_nz (fdirc);
+    return coli == rowj && my_idx == coli;
+  }
+
+  double
+  interp_idx (void) const
+  {
+    return default_interp_idx(*this);  // generic row/col/dims conversion
+  }
+  octave_idx_type
+  flat_idx (void) const
+  {
+    return default_flat_idx(*this);  // generic row/col/dims conversion
+  }
+  bool  // linear-index overload
+  skip_ahead (octave_idx_type i)
+  {
+    return default_skip_ahead (*this, i);  // split i, then skip_ahead (row, col)
   }
 };
 
-class perm_iterator
+class perm_iterator : public nz_iterator<PermMatrix>
 {
 private:
-  const PermMatrix& mat;
+  typedef nz_iterator<PermMatrix> Base;
+
   octave_idx_type my_idx;
 
 public:
   perm_iterator (const PermMatrix& arg_mat) :
-      mat (arg_mat)
+      nz_iterator<PermMatrix> (arg_mat)
   {
-    dir_handler<FORWARD> dirc;
-    begin (dirc);
+    begin (fdirc);  // fdirc: shared forward handler replacing the local dirc
   }
 
   template<direction dir>
   void
   begin (dir_handler<dir> dirc)
   {
-    my_idx = dirc.begin (mat.cols ());
+    my_idx = dirc.begin (Base::mat.cols ());  // my_idx doubles as the column
   }
 
   octave_idx_type
-  interp_idx (void) const
-  {
-    return to_interp_idx (row (), col (), mat.dims ());
-  }
-  octave_idx_type
   col (void) const
   {
     return my_idx;
@@ -379,8 +529,9 @@
   octave_idx_type
   row (void) const
   {
-    return mat.perm_elem (my_idx);
+    return Base::mat.perm_elem (my_idx);  // row of column my_idx's entry
   }
+
   bool
   data (void) const
   {
@@ -396,8 +547,68 @@
   bool
   finished (dir_handler<dir> dirc) const
   {
-    return dirc.is_ended (my_idx, mat.rows ());
+    return dirc.is_ended (my_idx, Base::mat.rows ());
+  }
+
+  bool  // true iff (rowj, coli) is itself the nonzero entry of column coli
+  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
+  {
+    //TODO Check bounds.  Column-major order: rowj matters only within coli.
+    if (coli < my_idx || (coli == my_idx && rowj < this->row ()))
+      return false;
+    my_idx = coli + (rowj > Base::mat.perm_elem (coli));  // past it -> next col
+    return my_idx == coli && rowj == this->row ();
+  }
+
+  double
+  interp_idx (void) const
+  {
+    return default_interp_idx(*this);  // generic row/col/dims conversion
+  }
+  octave_idx_type
+  flat_idx (void) const
+  {
+    return default_flat_idx(*this);  // generic row/col/dims conversion
+  }
+  bool  // linear-index overload
+  skip_ahead (octave_idx_type i)
+  {
+    return default_skip_ahead (*this, i);  // split i, then skip_ahead (row, col)
   }
 };
 
+// One-sided binary search within column coli for the first entry at or after
+// my_idx whose row is >= rowj (or the column's end).  hi is tested against the
+// column bound before ridx (hi) is read, so an exhausted column is safe.
+template<typename T>
+void
+sparse_iterator<T>::jump_to_row (octave_idx_type rowj)
+{
+  const octave_idx_type ub = Base::mat.cidx (coli + 1);  // one past the column
+  octave_idx_type lo = my_idx - 1;  // entries at or before lo have row < rowj
+  octave_idx_type hi = my_idx;      // candidate answer
+  octave_idx_type hidiff = 1;
+  while (hi < ub && Base::mat.ridx (hi) < rowj)
+    {
+      lo = hi;
+      hidiff *= 2;  // widen the stride each round
+      hi += hidiff;
+      if (hi >= ub)
+        {
+          hi = ub;  // clamp; slot ub means "no such row in this column"
+          break;
+        }
+    }
+  while (hi - lo > 1)  // bisect the bracket (lo, hi]
+    {
+      octave_idx_type mid = (lo + hi) / 2;
+      if (Base::mat.ridx (mid) < rowj)
+        lo = mid;
+      else
+        hi = mid;
+    }
+  my_idx = hi;
+  adjust_col (fdirc);  // move coli on if the column was exhausted
+}
+
 #endif