Mercurial > octave-dspies
diff liboctave/util/nz-iterators.h @ 19010:3fb030666878 draft default tip dspies
Added special-case logical-indexing function
* logical-index.h (New file) : Logical-indexing function. May be called on
octave_value types via call_bool_index
* nz-iterators.h : Add base-class nz_iterator for iterator types. array_iterator
takes a template bool (userc) for whether to internally store row-col or compute it on the fly
Add skip_ahead method which skips forward to the nearest nonzero at or after its
argument
Add flat_idx for computing the octave_idx_type index of the current position (with
assertion failure in the case of overflow)
Move is_zero to separate file
* ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc
(do_index_op): Add call to call_bool_index in logical-index.h
* Array.h : Move forward-declaration for array_iterator to separate header file
* dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx)
* array-iter-decl.h (New file): Header file for forward declaration of
array-iterator
* direction.h : Add constants fdirc and bdirc to avoid having to reconstruct
them
* dv-utils.h, dv-utils.cc (New files) :
Utility functions for querying and constructing dim-vectors
* idx-bounds.h (New file) :
Utility constants and functions for determining whether things will overflow
the maximum allowed bounds
* interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear
index of octave_idx_type
* is-zero.h (New file) : Function for determining whether an element is zero
* logical-index.tst : Add tests for correct return-value dimensions and large
sparse matrix behavior
author | David Spies <dnspies@gmail.com> |
---|---|
date | Fri, 25 Jul 2014 13:39:31 -0600 |
parents | 8d47ce2053f2 |
children |
line wrap: on
line diff
--- a/liboctave/util/nz-iterators.h Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/util/nz-iterators.h Fri Jul 25 13:39:31 2014 -0600 @@ -22,6 +22,9 @@ #if !defined (octave_nz_iterators_h) #define octave_nz_iterators_h 1 +#include <cassert> + +#include "array-iter-decl.h" #include "interp-idx.h" #include "oct-inttypes.h" #include "Array.h" @@ -29,6 +32,7 @@ #include "PermMatrix.h" #include "Sparse.h" #include "direction.h" +#include "is-zero.h" // This file contains generic column-major iterators over // the nonzero elements of any array or matrix. If you have a matrix mat @@ -61,67 +65,104 @@ // T elem = iter.data(); // // ... Do something with these // } -// -// Note that array_iter for indexing over full matrices also includes -// a iter.flat_index () method which returns an octave_idx_type. -// -// The other iterators to not have a flat_index() method because they -// risk overflowing octave_idx_type. It is recommended you take care -// to implement your function in a way that accounts for this problem. -// -// FIXME: I'd like to add in these -// default no-parameter versions of -// begin() and step() to each of the -// classes. But the C++ compiler complains -// because apparently I'm not allowed to overload -// templated methods with non-templated ones. Any -// ideas for work-arounds? 
-// -//#define INCLUDE_DEFAULT_STEPS \ -// void begin (void) \ -// { \ -// dir_handler<FORWARD> dirc; \ -// begin (dirc); \ -// } \ -// void step (void) \ -// { \ -// dir_handler<FORWARD> dirc; \ -// step (dirc); \ -// } \ -// bool finished (void) const \ -// { \ -// dir_handler<FORWARD> dirc; \ -// return finished (dirc); \ -// } + +struct rowcol +{ + rowcol (void) { } + + rowcol (octave_idx_type i, const dim_vector& dims) + { +#if defined(BOUNDS_CHECKING) + check_index(i, dims); +#endif + row = i % dims(0); + col = i / dims(0); + } + + rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims) + { +#if defined(BOUNDS_CHECKING) + check_index(rowj, coli, dims); +#endif + row = rowj; + col = coli; + } + + octave_idx_type row; + octave_idx_type col; +}; + +// Default (non-linear) implementation of nz_iterator.interp_idx +template <typename Iter> +inline double +default_interp_idx(const Iter& it) +{ + return to_interp_idx (it.row (), it.col (), it.dims); +} -// A generic method for checking if some element of a matrix with -// element type T is zero. -template<typename T> -bool -is_zero (T t) +// Default (non-linear) implementation of nz_iterator.flat_idx +template <typename Iter> +inline octave_idx_type +default_flat_idx (const Iter& it) +{ + return to_flat_idx (it.row (), it.col (), it.dims); +} + +// Default (non-linear) implementation of +// nz_iterator.skip_ahead(octave_idx_type) +template <typename Iter> +inline bool +default_skip_ahead (Iter& it, octave_idx_type i) +{ + return it.skip_ahead (i % it.dims(0), i / it.dims(0)); +} + + +template <typename M> +class nz_iterator { - return t == static_cast<T> (0); -} +protected: + const M mat; + +public: + const dim_vector dims; + nz_iterator (const M& arg_mat) : mat (arg_mat), dims (arg_mat.dims ()) { } + + // Because of the abundance of templates, this interface cannot be specified + // using pure virtual methods. 
But here's what each method does: + // + // begin(dirc): moves the iterator to the beginning or the end + // finished(dirc): checks whether the iterator has advanced to the end or the + // beginning + // row, col, and data: return the current row, column, and element + // interp_idx and flat_idx: return the current linear index (interp_idx as + // a double the way the interpreter sees it and flat_idx as an + // octave_idx_type) + // step(dirc): steps forward to the next nonzero element + // skip_ahead(args): skips forward to the nearest nonzero element after + // (including) args. Returns true iff the position at args is nonzero +}; // An iterator over full arrays. When the number of dimensions exceeds // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1) // // This mimics the behavior of the "find" method (both in Octave and Matlab) // on many-dimensional matrices. +// +// userc indicates whether to (true) track row and column values or (false) +// compute them on the fly as needed -template<typename T> -class array_iterator +template<typename T, bool userc> +class array_iterator : public nz_iterator<Array<T> > { private: - const Array<T>& mat; + typedef nz_iterator<Array<T> > Base; - //Actual total number of columns = mat.dims().numel(1) - //can be different from length of row dimension const octave_idx_type totcols; const octave_idx_type numels; - octave_idx_type coli; - octave_idx_type rowj; + rowcol myrc; + octave_idx_type my_idx; template<direction dir> @@ -129,11 +170,14 @@ step_once (dir_handler<dir> dirc) { my_idx += dir; - rowj += dir; - if (dirc.is_ended (rowj, mat.rows ())) + if (userc) { - rowj = dirc.begin (mat.rows ()); - coli += dir; + myrc.row += dir; + if (dirc.is_ended (myrc.row, Base::mat.rows ())) + { + myrc.row = dirc.begin (Base::mat.rows ()); + myrc.col += dir; + } } } @@ -147,35 +191,66 @@ } } + template <direction dir> + bool + next_nz (dir_handler<dir> dirc) + { + if (is_zero (data ())) + { + step (dirc); + return 
false; + } + else + return true; + } + + bool + skip_ahead (octave_idx_type i, const rowcol& rc) + { + if (i < my_idx) + return false; + my_idx = i; + if (userc) + myrc = rc; + return next_nz (fdirc); + } public: array_iterator (const Array<T>& arg_mat) - : mat (arg_mat), totcols (arg_mat.dims ().numel (1)), numels ( + : nz_iterator<Array<T> > (arg_mat), totcols (Base::dims.numel (1)), numels ( totcols * arg_mat.rows ()) { - dir_handler<FORWARD> dirc; - begin (dirc); + begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { - coli = dirc.begin (totcols); - rowj = dirc.begin (mat.rows ()); - my_idx = dirc.begin (mat.numel ()); + if(userc) + { + myrc.col = dirc.begin (totcols); + myrc.row = dirc.begin (Base::mat.rows ()); + } + my_idx = dirc.begin (Base::mat.numel ()); move_to_nz (dirc); } octave_idx_type col (void) const { - return coli; + if (userc) + return myrc.col; + else + return my_idx / Base::mat.rows (); } octave_idx_type row (void) const { - return rowj; + if (userc) + return myrc.row; + else + return my_idx % Base::mat.rows (); } double interp_idx (void) const @@ -190,7 +265,7 @@ T data (void) const { - return mat.elem (my_idx); + return Base::mat.elem (my_idx); } template<direction dir> @@ -206,13 +281,41 @@ { return dirc.is_ended (my_idx, numels); } + + bool + skip_ahead (octave_idx_type i) + { + return skip_ahead (i, rowcol (i, Base::dims)); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type coli) + { + return skip_ahead (coli * Base::mat.rows () + rowj, + rowcol (rowj, coli, Base::dims)); + } + + bool + skip_ahead (const Array<octave_idx_type>& idxs) + { + //TODO Check bounds + octave_idx_type rowj = idxs(0); + octave_idx_type coli = 0; + for(int i = idxs.numel () - 1; i > 1; --i) { + coli += idxs(i); + coli *= Base::dims(i); + } + coli += idxs(1); + return skip_ahead (rowj, coli); + } }; template<typename T> -class sparse_iterator +class sparse_iterator : public nz_iterator<Sparse<T> > { private: - const Sparse<T>& 
mat; + typedef nz_iterator<Sparse<T> > Base; + octave_idx_type coli; octave_idx_type my_idx; @@ -221,32 +324,28 @@ adjust_col (dir_handler<dir> dirc) { while (!finished (dirc) - && dirc.is_ended (my_idx, mat.cidx (coli), mat.cidx (coli + 1))) + && dirc.is_ended (my_idx, Base::mat.cidx (coli), Base::mat.cidx (coli + 1))) coli += dir; } + void jump_to_row (octave_idx_type rowj); + public: sparse_iterator (const Sparse<T>& arg_mat) : - mat (arg_mat) + nz_iterator<Sparse<T> > (arg_mat) { - dir_handler<FORWARD> dirc; - begin (dirc); + begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { - coli = dirc.begin (mat.cols ()); - my_idx = dirc.begin (mat.nnz ()); + coli = dirc.begin (Base::mat.cols ()); + my_idx = dirc.begin (Base::mat.nnz ()); adjust_col (dirc); } - double - interp_idx (void) const - { - return to_interp_idx (row (), col (), mat.dims ()); - } octave_idx_type col (void) const { @@ -255,12 +354,13 @@ octave_idx_type row (void) const { - return mat.ridx (my_idx); + return Base::mat.ridx (my_idx); } + T data (void) const { - return mat.data (my_idx); + return Base::mat.data (my_idx); } template<direction dir> void @@ -273,15 +373,49 @@ bool finished (dir_handler<dir> dirc) const { - return dirc.is_ended (coli, mat.cols ()); + return dirc.is_ended (coli, Base::mat.cols ()); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli) + { + //TODO Check bounds + if (arg_coli < coli || (arg_coli == coli && rowj < this->row ())) + { + return false; + } + else if (arg_coli > coli) + { + coli = arg_coli; + my_idx = Base::mat.cidx (arg_coli); + } + jump_to_row (rowj); + return coli == arg_coli && this->row () == rowj; + } + + double + interp_idx (void) const + { + return default_interp_idx(*this); + } + octave_idx_type + flat_idx (void) const + { + return default_flat_idx(*this); + } + bool + skip_ahead (octave_idx_type i) + { + return default_skip_ahead (*this, i); } }; template<typename T> -class diag_iterator +class 
diag_iterator : public nz_iterator <DiagArray2<T> > { private: - const DiagArray2<T>& mat; + typedef nz_iterator <DiagArray2<T> > Base; + octave_idx_type my_idx; template <direction dir> @@ -296,25 +430,19 @@ public: diag_iterator (const DiagArray2<T>& arg_mat) : - mat (arg_mat) + nz_iterator<DiagArray2<T> > (arg_mat) { - dir_handler<FORWARD> dirc; - begin (dirc); + begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { - my_idx = dirc.begin (mat.diag_length ()); + my_idx = dirc.begin (Base::mat.diag_length ()); move_to_nz (dirc); } - double - interp_idx (void) const - { - return to_interp_idx (row (), col (), mat.dims ()); - } octave_idx_type col (void) const { @@ -325,10 +453,11 @@ { return my_idx; } + T data (void) const { - return mat.dgelem (my_idx); + return Base::mat.dgelem (my_idx); } template<direction dir> void @@ -341,37 +470,58 @@ bool finished (dir_handler<dir> dirc) const { - return dirc.is_ended (my_idx, mat.diag_length ()); + return dirc.is_ended (my_idx, Base::mat.diag_length ()); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type coli) + { + if (coli < my_idx) + return false; + my_idx = coli + (rowj > coli); + move_to_nz (fdirc); + return rowj == coli && coli == my_idx; + } + + double + interp_idx (void) const + { + return default_interp_idx(*this); + } + octave_idx_type + flat_idx (void) const + { + return default_flat_idx(*this); + } + bool + skip_ahead (octave_idx_type i) + { + return default_skip_ahead (*this, i); } }; -class perm_iterator +class perm_iterator : public nz_iterator<PermMatrix> { private: - const PermMatrix& mat; + typedef nz_iterator<PermMatrix> Base; + octave_idx_type my_idx; public: perm_iterator (const PermMatrix& arg_mat) : - mat (arg_mat) + nz_iterator<PermMatrix> (arg_mat) { - dir_handler<FORWARD> dirc; - begin (dirc); + begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { - my_idx = dirc.begin (mat.cols ()); + my_idx = dirc.begin (Base::mat.cols ()); } 
octave_idx_type - interp_idx (void) const - { - return to_interp_idx (row (), col (), mat.dims ()); - } - octave_idx_type col (void) const { return my_idx; @@ -379,8 +529,9 @@ octave_idx_type row (void) const { - return mat.perm_elem (my_idx); + return Base::mat.perm_elem (my_idx); } + bool data (void) const { @@ -396,8 +547,68 @@ bool finished (dir_handler<dir> dirc) const { - return dirc.is_ended (my_idx, mat.rows ()); + return dirc.is_ended (my_idx, Base::mat.rows ()); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type coli) + { + //TODO Check bounds + if (coli < my_idx || rowj < this->row ()) + return false; + my_idx = coli + (rowj > Base::mat.perm_elem (coli)); + return my_idx == coli && rowj == this->row (); + } + + double + interp_idx (void) const + { + return default_interp_idx(*this); + } + octave_idx_type + flat_idx (void) const + { + return default_flat_idx(*this); + } + bool + skip_ahead (octave_idx_type i) + { + return default_skip_ahead (*this, i); } }; +// Uses a one-sided binary search to move to the next element in this column +// whose row is at least rowj. The one-sided binary search guarantees +// O(log(rowj - currentRow)) time to find it. +template<typename T> +void +sparse_iterator<T>::jump_to_row (octave_idx_type rowj) +{ + octave_idx_type ub = Base::mat.cidx (coli + 1); + octave_idx_type lo = my_idx - 1; + octave_idx_type hi = my_idx; + octave_idx_type hidiff = 1; + while (Base::mat.ridx (hi) < rowj) + { + lo = hi; + hidiff *= 2; + hi += hidiff; + if (hi >= ub) + { + hi = ub; + break; + } + } + while (hi - lo > 1) + { + octave_idx_type mid = (lo + hi) / 2; + if (Base::mat.ridx (mid) < rowj) + lo = mid; + else + hi = mid; + } + my_idx = hi; + adjust_col (fdirc); +} + #endif