Mercurial > octave-dspies
view liboctave/util/nz-iterators.h @ 19010:3fb030666878 draft default tip dspies
Added special-case logical-indexing function
* logical-index.h (New file) : Logical-indexing function. May be called on
octave_value types via call_bool_index
* nz-iterators.h : Add base-class nz_iterator for iterator types. Array has
template bool for whether to internally store row-col or compute on the fly
Add skip_ahead method which skips forward to the next nonzero after its
argument
Add flat_index for computing octave_idx_type index of current position (with
assertion failure in the case of overflow)
Move is_zero to separate file
* ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc
(do_index_op): Add call to call_bool_index in logical-index.h
* Array.h : Move forward-declaration for array_iterator to separate header file
* dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx)
* array-iter-decl.h (New file): Header file for forward declaration of
array-iterator
* direction.h : Add constants fdirc and bdirc to avoid having to reconstruct
them
* dv-utils.h, dv-utils.cc (New files) :
Utility functions for querying and constructing dim-vectors
* idx-bounds.h (New file) :
Utility constants and functions for determining whether things will overflow
the maximum allowed bounds
* interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear
index of octave_idx_type
* is-zero.h (New file) : Function for determining whether an element is zero
* logical-index.tst : Add tests for correct return-value dimensions and large
sparse matrix behavior
author | David Spies <dnspies@gmail.com> |
---|---|
date | Fri, 25 Jul 2014 13:39:31 -0600 |
parents | 8d47ce2053f2 |
children |
line wrap: on
line source
/* Copyright (C) 2014 David Spies This file is part of Octave. Octave is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version. Octave is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Octave; see the file COPYING. If not, see <http://www.gnu.org/licenses/>. */ #if !defined (octave_nz_iterators_h) #define octave_nz_iterators_h 1 #include <cassert> #include "array-iter-decl.h" #include "interp-idx.h" #include "oct-inttypes.h" #include "Array.h" #include "DiagArray2.h" #include "PermMatrix.h" #include "Sparse.h" #include "direction.h" #include "is-zero.h" // This file contains generic column-major iterators over // the nonzero elements of any array or matrix. If you have a matrix mat // of type M, you can construct the proper iterator type using // M::iter_type iter(mat) and iter will iterate efficiently (forwards // or backwards) over the nonzero elements of mat. // // The parameter T is the element-type except for PermMatrix where // the element type is always bool // // Use a dir_handler to indicate which direction you intend to iterate. // (see step-dir.h). begin() resets the iterator to the beginning or // end of the matrix (for dir_handler<FORWARD> and <BACKWARD> respectively). // finished(dirc) indicates whether the iterators has finished traversing // the nonzero elements. step(dirc) steps from one element to the next. // // You can, for instance, use a for-loop as follows: // // typedef M::iter_type iter_t; // typedef M::element_type T; // // iter_t iter(mat); // dir_handler<1> dirc; // // for(iter.begin (dirc); !iter.finished (dirc); iter.step (dirc)) // { // octave_idx_type row = iter.row(); // octave_idx_type col = iter.col(); // double doub_index = iter.interp_index (); // T elem = iter.data(); // // ... Do something with these // } struct rowcol { rowcol (void) { } rowcol (octave_idx_type i, const dim_vector& dims) { #if defined(BOUNDS_CHECKING) check_index(i, dims); #endif row = i % dims(0); col = i / dims(0); } rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims) { #if defined(BOUNDS_CHECKING) check_index(rowj, coli, dims); #endif row = rowj; col = coli; } octave_idx_type row; octave_idx_type col; }; // Default (non-linear) implementation of nz_iterator.interp_idx template <typename Iter> inline double default_interp_idx(const Iter& it) { return to_interp_idx (it.row (), it.col (), it.dims); } // Default (non-linear) implementation of nz_iterator.flat_idx template <typename Iter> inline octave_idx_type default_flat_idx (const Iter& it) { return to_flat_idx (it.row (), it.col (), it.dims); } // Default (non-linear) implementation of // nz_iterator.skip_ahead(octave_idx_type) template <typename Iter> inline bool default_skip_ahead (Iter& it, octave_idx_type i) { return it.skip_ahead (i % it.dims(0), i / it.dims(0)); } template <typename M> class nz_iterator { protected: const M mat; public: const dim_vector dims; nz_iterator (const M& arg_mat) : mat (arg_mat), dims (arg_mat.dims ()) { } // Because of the abundance of templates, this interface cannot be specified // using pure virtual methods. But here's what each method does: // // begin(dirc): moves the iterator to the beginning or the end // finished(dirc): checks whether the iterator has advanced to the end or the // beginning // row, col, and data: return the current row, column, and element // interp_idx and flat_idx: return the current linear index (interp_idx as // a double the way the interpreter sees it and flat_idx as an // octave_idx_type) // step(dirc): steps forward to the next nonzero element // skip_ahead(args): skips forward to the nearest nonzero element after // (including) args. Returns true iff the position at args is nonzero }; // An iterator over full arrays. When the number of dimensions exceeds // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1) // // This mimics the behavior of the "find" method (both in Octave and Matlab) // on many-dimensional matrices. // // userc indicates whether to (true) track row and column values or (false) // compute them on the fly as needed template<typename T, bool userc> class array_iterator : public nz_iterator<Array<T> > { private: typedef nz_iterator<Array<T> > Base; const octave_idx_type totcols; const octave_idx_type numels; rowcol myrc; octave_idx_type my_idx; template<direction dir> void step_once (dir_handler<dir> dirc) { my_idx += dir; if (userc) { myrc.row += dir; if (dirc.is_ended (myrc.row, Base::mat.rows ())) { myrc.row = dirc.begin (Base::mat.rows ()); myrc.col += dir; } } } template<direction dir> void move_to_nz (dir_handler<dir> dirc) { while (!finished (dirc) && is_zero (data ())) { step_once (dirc); } } template <direction dir> bool next_nz (dir_handler<dir> dirc) { if (is_zero (data ())) { step (dirc); return false; } else return true; } bool skip_ahead (octave_idx_type i, const rowcol& rc) { if (i < my_idx) return false; my_idx = i; if (userc) myrc = rc; return next_nz (fdirc); } public: array_iterator (const Array<T>& arg_mat) : nz_iterator<Array<T> > (arg_mat), totcols (Base::dims.numel (1)), numels ( totcols * arg_mat.rows ()) { begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { if(userc) { myrc.col = dirc.begin (totcols); myrc.row = dirc.begin (Base::mat.rows ()); } my_idx = dirc.begin (Base::mat.numel ()); move_to_nz (dirc); } octave_idx_type col (void) const { if (userc) return myrc.col; else return my_idx / Base::mat.rows (); } octave_idx_type row (void) const { if (userc) return myrc.row; else return my_idx % Base::mat.rows (); } double interp_idx (void) const { return to_interp_idx (my_idx); } octave_idx_type flat_idx (void) const { return my_idx; } T data (void) const { return Base::mat.elem (my_idx); } template<direction dir> void step (dir_handler<dir> dirc) { step_once (dirc); move_to_nz (dirc); } template<direction dir> bool finished (dir_handler<dir> dirc) const { return dirc.is_ended (my_idx, numels); } bool skip_ahead (octave_idx_type i) { return skip_ahead (i, rowcol (i, Base::dims)); } bool skip_ahead (octave_idx_type rowj, octave_idx_type coli) { return skip_ahead (coli * Base::mat.rows () + rowj, rowcol (rowj, coli, Base::dims)); } bool skip_ahead (const Array<octave_idx_type>& idxs) { //TODO Check bounds octave_idx_type rowj = idxs(0); octave_idx_type coli = 0; for(int i = idxs.numel () - 1; i > 1; --i) { coli += idxs(i); coli *= Base::dims(i); } coli += idxs(1); return skip_ahead (rowj, coli); } }; template<typename T> class sparse_iterator : public nz_iterator<Sparse<T> > { private: typedef nz_iterator<Sparse<T> > Base; octave_idx_type coli; octave_idx_type my_idx; template<direction dir> void adjust_col (dir_handler<dir> dirc) { while (!finished (dirc) && dirc.is_ended (my_idx, Base::mat.cidx (coli), Base::mat.cidx (coli + 1))) coli += dir; } void jump_to_row (octave_idx_type rowj); public: sparse_iterator (const Sparse<T>& arg_mat) : nz_iterator<Sparse<T> > (arg_mat) { begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { coli = dirc.begin (Base::mat.cols ()); my_idx = dirc.begin (Base::mat.nnz ()); adjust_col (dirc); } octave_idx_type col (void) const { return coli; } octave_idx_type row (void) const { return Base::mat.ridx (my_idx); } T data (void) const { return Base::mat.data (my_idx); } template<direction dir> void step (dir_handler<dir> dirc) { my_idx += dir; adjust_col (dirc); } template<direction dir> bool finished (dir_handler<dir> dirc) const { return dirc.is_ended (coli, Base::mat.cols ()); } bool skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli) { //TODO Check bounds if (arg_coli < coli || (arg_coli == coli && rowj < this->row ())) { return false; } else if (arg_coli > coli) { coli = arg_coli; my_idx = Base::mat.cidx (arg_coli); } jump_to_row (rowj); return coli == arg_coli && this->row () == rowj; } double interp_idx (void) const { return default_interp_idx(*this); } octave_idx_type flat_idx (void) const { return default_flat_idx(*this); } bool skip_ahead (octave_idx_type i) { return default_skip_ahead (*this, i); } }; template<typename T> class diag_iterator : public nz_iterator <DiagArray2<T> > { private: typedef nz_iterator <DiagArray2<T> > Base; octave_idx_type my_idx; template <direction dir> void move_to_nz (dir_handler<dir> dirc) { while (!finished (dirc) && is_zero (data ())) { my_idx += dir; } } public: diag_iterator (const DiagArray2<T>& arg_mat) : nz_iterator<DiagArray2<T> > (arg_mat) { begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { my_idx = dirc.begin (Base::mat.diag_length ()); move_to_nz (dirc); } octave_idx_type col (void) const { return my_idx; } octave_idx_type row (void) const { return my_idx; } T data (void) const { return Base::mat.dgelem (my_idx); } template<direction dir> void step (dir_handler<dir> dirc) { my_idx += dir; move_to_nz (dirc); } template<direction dir> bool finished (dir_handler<dir> dirc) const { return dirc.is_ended (my_idx, Base::mat.diag_length ()); } bool skip_ahead (octave_idx_type rowj, octave_idx_type coli) { if (coli < my_idx) return false; my_idx = coli + (rowj > coli); move_to_nz (fdirc); return rowj == coli && coli == my_idx; } double interp_idx (void) const { return default_interp_idx(*this); } octave_idx_type flat_idx (void) const { return default_flat_idx(*this); } bool skip_ahead (octave_idx_type i) { return default_skip_ahead (*this, i); } }; class perm_iterator : public nz_iterator<PermMatrix> { private: typedef nz_iterator<PermMatrix> Base; octave_idx_type my_idx; public: perm_iterator (const PermMatrix& arg_mat) : nz_iterator<PermMatrix> (arg_mat) { begin (fdirc); } template<direction dir> void begin (dir_handler<dir> dirc) { my_idx = dirc.begin (Base::mat.cols ()); } octave_idx_type col (void) const { return my_idx; } octave_idx_type row (void) const { return Base::mat.perm_elem (my_idx); } bool data (void) const { return true; } template<direction dir> void step (dir_handler<dir>) { my_idx += dir; } template<direction dir> bool finished (dir_handler<dir> dirc) const { return dirc.is_ended (my_idx, Base::mat.rows ()); } bool skip_ahead (octave_idx_type rowj, octave_idx_type coli) { //TODO Check bounds if (coli < my_idx || rowj < this->row ()) return false; my_idx = coli + (rowj > Base::mat.perm_elem (coli)); return my_idx == coli && rowj == this->row (); } double interp_idx (void) const { return default_interp_idx(*this); } octave_idx_type flat_idx (void) const { return default_flat_idx(*this); } bool skip_ahead (octave_idx_type i) { return default_skip_ahead (*this, i); } }; // Uses a one-sided binary search to move to the next element in this column // whose row is at least rowj. The one-sided binary search guarantees // O(log(rowj - currentRow)) time to find it. template<typename T> void sparse_iterator<T>::jump_to_row (octave_idx_type rowj) { octave_idx_type ub = Base::mat.cidx (coli + 1); octave_idx_type lo = my_idx - 1; octave_idx_type hi = my_idx; octave_idx_type hidiff = 1; while (Base::mat.ridx (hi) < rowj) { lo = hi; hidiff *= 2; hi += hidiff; if (hi >= ub) { hi = ub; break; } } while (hi - lo > 1) { octave_idx_type mid = (lo + hi) / 2; if (Base::mat.ridx (mid) < rowj) lo = mid; else hi = mid; } my_idx = hi; adjust_col (fdirc); } #endif