view liboctave/util/nz-iterators.h @ 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function * logical-index.h (New file) : Logical-indexing function. May be called on octave_value types via call_bool_index * nz-iterators.h : Add base-class nz_iterator for iterator types. Array has template bool for whether to internally store row-col or compute on the fly Add skip_ahead method which skips forward to the next nonzero after its argument Add flat_index for computing octave_idx_type index of current position (with assertion failure in the case of overflow) Move is_zero to separate file * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h * Array.h : Move forward-declaration for array_iterator to separate header file * dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx) * array-iter-decl.h (New file): Header file for forward declaration of array-iterator * direction.h : Add constants fdirc and bdirc to avoid having to reconstruct them * dv-utils.h, dv-utils.cc (New files) : Utility functions for querying and constructing dim-vectors * idx-bounds.h (New file) : Utility constants and functions for determining whether things will overflow the maximum allowed bounds * interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear index of octave_idx_type * is-zero.h (New file) : Function for determining whether an element is zero * logical-index.tst : Add tests for correct return-value dimensions and large sparse matrix behavior
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents 8d47ce2053f2
children
line wrap: on
line source

/*

Copyright (C) 2014 David Spies

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/
#if !defined (octave_nz_iterators_h)
#define octave_nz_iterators_h 1

#include <cassert>

#include "array-iter-decl.h"
#include "interp-idx.h"
#include "oct-inttypes.h"
#include "Array.h"
#include "DiagArray2.h"
#include "PermMatrix.h"
#include "Sparse.h"
#include "direction.h"
#include "is-zero.h"

// This file contains generic column-major iterators over
// the nonzero elements of any array or matrix.  If you have a matrix mat
// of type M, you can construct the proper iterator type using
// M::iter_type iter(mat) and iter will iterate efficiently (forwards
// or backwards) over the nonzero elements of mat.
//
// The parameter T is the element-type except for PermMatrix where
// the element type is always bool
//
// Use a dir_handler to indicate which direction you intend to iterate
// (see direction.h).  begin(dirc) resets the iterator to the beginning or
// end of the matrix (for dir_handler<FORWARD> and <BACKWARD> respectively).
// finished(dirc) indicates whether the iterator has finished traversing
// the nonzero elements.  step(dirc) steps from one element to the next.
//
// You can, for instance, use a for-loop as follows:
//
// typedef M::iter_type iter_t;
// typedef M::element_type T;
//
// iter_t iter(mat);
// dir_handler<FORWARD> dirc;
//
// for(iter.begin (dirc); !iter.finished (dirc); iter.step (dirc))
// {
//   octave_idx_type row = iter.row();
//   octave_idx_type col = iter.col();
//   double doub_index = iter.interp_idx ();
//   T elem = iter.data();
//   // ... Do something with these
// }

// A (row, column) position.  It can be built either from a column-major
// linear index or from an explicit row/column pair.
struct rowcol
{
  // Leaves row and col unset; assign both before reading them.
  rowcol (void) { }

  // Split the linear index i into row = i % dims(0) and col = i / dims(0).
  // The index is validated only when BOUNDS_CHECKING is defined.
  rowcol (octave_idx_type i, const dim_vector& dims)
  {
#if defined(BOUNDS_CHECKING)
    check_index (i, dims);
#endif
    const octave_idx_type nrows = dims(0);
    row = i % nrows;
    col = i / nrows;
  }

  // Store the given row and column as they are.  The pair is validated
  // against dims only when BOUNDS_CHECKING is defined.
  rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims)
    : row (rowj), col (coli)
  {
#if defined(BOUNDS_CHECKING)
    check_index (rowj, coli, dims);
#endif
  }

  octave_idx_type row;
  octave_idx_type col;
};

// Default (non-linear) implementation of nz_iterator.interp_idx
template <typename Iter>
inline double
default_interp_idx(const Iter& it)
{
  return to_interp_idx (it.row (), it.col (), it.dims);
}

// Fallback implementation of nz_iterator.flat_idx for iterators that do
// not track a linear index: converts the iterator's current row and column,
// together with its dims, through to_flat_idx.
template <typename Iter>
inline octave_idx_type
default_flat_idx (const Iter& it)
{
  const octave_idx_type cur_row = it.row ();
  const octave_idx_type cur_col = it.col ();
  return to_flat_idx (cur_row, cur_col, it.dims);
}

// Default (non-linear) implementation of
// nz_iterator.skip_ahead(octave_idx_type)
template <typename Iter>
inline bool
default_skip_ahead (Iter& it, octave_idx_type i)
{
  return it.skip_ahead (i % it.dims(0), i / it.dims(0));
}


// Common base of every nonzero iterator.  It holds its own copy of the
// matrix being traversed (mat) and that matrix's dimensions (dims).
template <typename M>
class nz_iterator
{
protected:
  const M mat;

public:
  const dim_vector dims;

  nz_iterator (const M& arg_mat)
    : mat (arg_mat), dims (arg_mat.dims ())
  { }

  // The iterator interface relies heavily on templates, so it cannot be
  // expressed as pure virtual methods.  Each derived iterator is expected
  // to provide the following members:
  //
  //   begin(dirc)      move the iterator to the beginning or the end
  //   finished(dirc)   check whether the iterator has advanced to the end
  //                    or the beginning
  //   row(), col(),    return the current row, column, and element
  //   data()
  //   interp_idx(),    return the current linear index; interp_idx as a
  //   flat_idx()       double the way the interpreter sees it, flat_idx
  //                    as an octave_idx_type
  //   step(dirc)       step to the next nonzero element
  //   skip_ahead(args) skip forward to the nearest nonzero element at or
  //                    after args; returns true iff the position at args
  //                    is nonzero
};

// An iterator over the nonzero elements of a full array, in column-major
// order.  When the number of dimensions exceeds 2, calls to iter.col()
// may exceed mat.cols() up to mat.dims().numel(1).
//
// This mimics the behavior of the "find" method (both in Octave and Matlab)
// on many-dimensional matrices.
//
// userc indicates whether to (true) track row and column values or (false)
// compute them on the fly as needed

template<typename T, bool userc>
class array_iterator : public nz_iterator<Array<T> >
{
private:
  typedef nz_iterator<Array<T> > Base;

  // Column count reported by dims.numel (1).
  const octave_idx_type totcols;
  // Total element count: totcols * rows.
  const octave_idx_type numels;

  // Current row/column; only maintained when userc is true.
  rowcol myrc;

  // Current column-major linear index into mat.
  octave_idx_type my_idx;

  // Move exactly one element in direction dir, without looking at the
  // element values.  When userc is set the row/column pair is kept in
  // step with my_idx, wrapping the row at column boundaries.
  template<direction dir>
  void
  step_once (dir_handler<dir> dirc)
  {
    my_idx += dir;
    if (userc)
      {
        myrc.row += dir;
        if (dirc.is_ended (myrc.row, Base::mat.rows ()))
          {
            myrc.row = dirc.begin (Base::mat.rows ());
            myrc.col += dir;
          }
      }
  }

  // Starting from (and including) the current element, advance in
  // direction dir until a nonzero element is found or the iterator is
  // finished.
  template<direction dir>
  void
  move_to_nz (dir_handler<dir> dirc)
  {
    while (!finished (dirc) && is_zero (data ()))
      {
        step_once (dirc);
      }
  }

  // Returns true if the current element is nonzero.  Otherwise steps to
  // the next nonzero element in direction dir and returns false.
  template <direction dir>
  bool
  next_nz (dir_handler<dir> dirc)
  {
    if (is_zero (data ()))
      {
        step (dirc);
        return false;
      }
    else
      return true;
  }

  // Shared implementation of the public skip_ahead overloads.  i is the
  // target linear index and rc the matching row/column pair.  Never moves
  // backwards: if i precedes the current position the iterator is left
  // untouched and false is returned.  Otherwise the iterator ends up on
  // the first nonzero at or after i, and the result tells whether element
  // i itself is nonzero.  i is not range-checked here.
  bool
  skip_ahead (octave_idx_type i, const rowcol& rc)
  {
    if (i < my_idx)
      return false;
    my_idx = i;
    if (userc)
      myrc = rc;
    return next_nz (fdirc);
  }

public:
  // Construct an iterator over arg_mat positioned on its first nonzero
  // element (forward direction).
  array_iterator (const Array<T>& arg_mat)
      : nz_iterator<Array<T> > (arg_mat), totcols (Base::dims.numel (1)), numels (
          totcols * arg_mat.rows ())
  {
    begin (fdirc);
  }

  // Reset to the first nonzero element as seen from direction dir.
  template<direction dir>
  void
  begin (dir_handler<dir> dirc)
  {
    if(userc)
      {
        myrc.col = dirc.begin (totcols);
        myrc.row = dirc.begin (Base::mat.rows ());
      }
    my_idx = dirc.begin (Base::mat.numel ());
    move_to_nz (dirc);
  }

  // Current column: the tracked value when userc is set, otherwise
  // derived from the linear index.
  octave_idx_type
  col (void) const
  {
    if (userc)
      return myrc.col;
    else
      return my_idx / Base::mat.rows ();
  }
  // Current row: the tracked value when userc is set, otherwise derived
  // from the linear index.
  octave_idx_type
  row (void) const
  {
    if (userc)
      return myrc.row;
    else
      return my_idx % Base::mat.rows ();
  }
  // Current linear index converted by to_interp_idx.
  double
  interp_idx (void) const
  {
    return to_interp_idx (my_idx);
  }
  // Current linear index as an octave_idx_type.
  octave_idx_type
  flat_idx (void) const
  {
    return my_idx;
  }
  // Element at the current position.
  T
  data (void) const
  {
    return Base::mat.elem (my_idx);
  }

  // Advance to the next nonzero element in direction dir.
  template<direction dir>
  void
  step (dir_handler<dir> dirc)
  {
    step_once (dirc);
    move_to_nz (dirc);
  }
  // True once the iterator has run off the end for direction dir.
  template<direction dir>
  bool
  finished (dir_handler<dir> dirc) const
  {
    return dirc.is_ended (my_idx, numels);
  }

  // Skip ahead to the linear index i (see the private overload for the
  // exact contract).
  bool
  skip_ahead (octave_idx_type i)
  {
    return skip_ahead (i, rowcol (i, Base::dims));
  }

  // Skip ahead to the position (rowj, coli) (see the private overload for
  // the exact contract).
  bool
  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
  {
    return skip_ahead (coli * Base::mat.rows () + rowj,
                       rowcol (rowj, coli, Base::dims));
  }

  // Skip ahead to the N-d subscript idxs (zero-based, one entry per
  // dimension; at least two entries are required).  Subscripts 1 and up
  // are collapsed into a single column index
  //
  //   idxs(1) + dims(1) * (idxs(2) + dims(2) * (idxs(3) + ...))
  //
  // which is evaluated from the last subscript downwards.
  bool
  skip_ahead (const Array<octave_idx_type>& idxs)
  {
    //TODO Check bounds
    octave_idx_type rowj = idxs(0);
    octave_idx_type coli = 0;
    for (int i = idxs.numel () - 1; i > 1; --i)
      {
        coli += idxs(i);
        // Subscript i counts blocks whose size is the extent of the next
        // lower dimension, so scale by dims(i - 1) rather than dims(i).
        coli *= Base::dims(i - 1);
      }
    coli += idxs(1);
    return skip_ahead (rowj, coli);
  }
};

// Iterator over the stored entries of a Sparse<T> matrix.  my_idx indexes
// the matrix's ridx/data arrays and coli is the column that my_idx belongs
// to, as determined from the cidx array.
template<typename T>
class sparse_iterator : public nz_iterator<Sparse<T> >
{
private:
  typedef nz_iterator<Sparse<T> > Base;

  octave_idx_type coli;
  octave_idx_type my_idx;

  // Move coli in direction dir until my_idx falls inside the range
  // [cidx (coli), cidx (coli + 1)) or the iterator is finished.
  template<direction dir>
  void
  adjust_col (dir_handler<dir> dirc)
  {
    for (;;)
      {
        if (finished (dirc))
          break;
        if (! dirc.is_ended (my_idx, Base::mat.cidx (coli),
                             Base::mat.cidx (coli + 1)))
          break;
        coli += dir;
      }
  }

  // Defined after the class; see the comment on the definition.
  void jump_to_row (octave_idx_type rowj);

public:
  // Construct an iterator over arg_mat positioned for forward traversal.
  sparse_iterator (const Sparse<T>& arg_mat) :
      nz_iterator<Sparse<T> > (arg_mat)
  {
    begin (fdirc);
  }

  // Reset to the first stored entry as seen from direction dir.
  template<direction dir>
  void
  begin (dir_handler<dir> dirc)
  {
    my_idx = dirc.begin (Base::mat.nnz ());
    coli = dirc.begin (Base::mat.cols ());
    adjust_col (dirc);
  }

  // Column of the current entry.
  octave_idx_type
  col (void) const
  {
    return coli;
  }

  // Row of the current entry.
  octave_idx_type
  row (void) const
  {
    return Base::mat.ridx (my_idx);
  }

  // Value of the current entry.
  T
  data (void) const
  {
    return Base::mat.data (my_idx);
  }

  // Advance to the neighbouring stored entry in direction dir.
  template<direction dir>
  void
  step (dir_handler<dir> dirc)
  {
    my_idx += dir;
    adjust_col (dirc);
  }

  // True once the column has run off the end for direction dir.
  template<direction dir>
  bool
  finished (dir_handler<dir> dirc) const
  {
    return dirc.is_ended (coli, Base::mat.cols ());
  }

  // Skip forward to the first stored entry at or after (rowj, arg_coli).
  // A target that precedes the current position leaves the iterator
  // untouched and yields false.  Otherwise the result is true iff an entry
  // is stored exactly at (rowj, arg_coli).
  bool
  skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli)
  {
    //TODO Check bounds
    if (arg_coli < coli)
      return false;

    if (arg_coli == coli)
      {
        if (rowj < this->row ())
          return false;
      }
    else
      {
        // Later column: restart the search at that column's first entry.
        coli = arg_coli;
        my_idx = Base::mat.cidx (arg_coli);
      }

    jump_to_row (rowj);

    if (coli != arg_coli)
      return false;
    return this->row () == rowj;
  }

  // Linear index of the current entry as the interpreter sees it.
  double
  interp_idx (void) const
  {
    return default_interp_idx (*this);
  }

  // Linear index of the current entry as an octave_idx_type.
  octave_idx_type
  flat_idx (void) const
  {
    return default_flat_idx (*this);
  }

  // Skip forward to the linear index i; see the two-argument overload.
  bool
  skip_ahead (octave_idx_type i)
  {
    return default_skip_ahead (*this, i);
  }
};

// Iterator over the nonzero diagonal elements of a DiagArray2<T>.  my_idx
// is the position along the diagonal, so it serves as both the row and
// the column of the current element.
template<typename T>
class diag_iterator : public nz_iterator <DiagArray2<T> >
{
private:
  typedef nz_iterator <DiagArray2<T> > Base;

  octave_idx_type my_idx;

  // Starting from (and including) the current diagonal element, advance
  // in direction dir until a nonzero element is found or the iterator is
  // finished.
  template <direction dir>
  void
  move_to_nz (dir_handler<dir> dirc)
  {
    for (; !finished (dirc) && is_zero (data ()); my_idx += dir)
      ;
  }

public:
  // Construct an iterator over arg_mat positioned for forward traversal.
  diag_iterator (const DiagArray2<T>& arg_mat) :
      nz_iterator<DiagArray2<T> > (arg_mat)
  {
    begin (fdirc);
  }

  // Reset to the first nonzero diagonal element as seen from direction dir.
  template<direction dir>
  void
  begin (dir_handler<dir> dirc)
  {
    my_idx = dirc.begin (Base::mat.diag_length ());
    move_to_nz (dirc);
  }

  // Column of the current element (equal to its diagonal position).
  octave_idx_type
  col (void) const
  {
    return my_idx;
  }

  // Row of the current element (equal to its diagonal position).
  octave_idx_type
  row (void) const
  {
    return my_idx;
  }

  // Value of the current diagonal element.
  T
  data (void) const
  {
    return Base::mat.dgelem (my_idx);
  }

  // Advance to the next nonzero diagonal element in direction dir.
  template<direction dir>
  void
  step (dir_handler<dir> dirc)
  {
    my_idx += dir;
    move_to_nz (dirc);
  }

  // True once the diagonal position has run off the end for direction dir.
  template<direction dir>
  bool
  finished (dir_handler<dir> dirc) const
  {
    return dirc.is_ended (my_idx, Base::mat.diag_length ());
  }

  // Skip forward to the first nonzero diagonal element at or after
  // (rowj, coli).  A target column before the current one leaves the
  // iterator untouched and yields false.  Otherwise the result is true
  // iff the target lies on the diagonal and that element is nonzero.
  bool
  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
  {
    if (coli < my_idx)
      return false;

    // A target below the diagonal is already past column coli's diagonal
    // element, so the search starts from the following one.
    if (rowj > coli)
      my_idx = coli + 1;
    else
      my_idx = coli;

    move_to_nz (fdirc);
    return my_idx == coli && rowj == coli;
  }

  // Linear index of the current element as the interpreter sees it.
  double
  interp_idx (void) const
  {
    return default_interp_idx (*this);
  }

  // Linear index of the current element as an octave_idx_type.
  octave_idx_type
  flat_idx (void) const
  {
    return default_flat_idx (*this);
  }

  // Skip forward to the linear index i; see the two-argument overload.
  bool
  skip_ahead (octave_idx_type i)
  {
    return default_skip_ahead (*this, i);
  }
};

// Iterator over the entries of a PermMatrix.  my_idx is the current
// column; the row of the entry in that column comes from perm_elem, and
// every entry is reported as true.
class perm_iterator : public nz_iterator<PermMatrix>
{
private:
  typedef nz_iterator<PermMatrix> Base;

  octave_idx_type my_idx;

public:
  // Construct an iterator over arg_mat positioned for forward traversal.
  perm_iterator (const PermMatrix& arg_mat) :
      nz_iterator<PermMatrix> (arg_mat)
  {
    begin (fdirc);
  }

  // Reset to the first column as seen from direction dir.
  template<direction dir>
  void
  begin (dir_handler<dir> dirc)
  {
    my_idx = dirc.begin (Base::mat.cols ());
  }

  // Column of the current entry.
  octave_idx_type
  col (void) const
  {
    return my_idx;
  }
  // Row of the current entry.
  octave_idx_type
  row (void) const
  {
    return Base::mat.perm_elem (my_idx);
  }

  // Value of the current entry; always true.
  bool
  data (void) const
  {
    return true;
  }
  // Advance one column in direction dir.
  template<direction dir>
  void
  step (dir_handler<dir>)
  {
    my_idx += dir;
  }
  // True once the column has run off the end for direction dir.
  template<direction dir>
  bool
  finished (dir_handler<dir> dirc) const
  {
    return dirc.is_ended (my_idx, Base::mat.rows ());
  }

  // Skip forward to the first entry at or after (rowj, coli).  A target
  // that precedes the current position leaves the iterator untouched and
  // yields false.  Otherwise the result is true iff the entry of column
  // coli sits exactly in row rowj.
  bool
  skip_ahead (octave_idx_type rowj, octave_idx_type coli)
  {
    //TODO Check bounds
    // The row only decides the ordering within the current column.  A
    // target in a later column is always ahead, whatever its row.
    if (coli < my_idx || (coli == my_idx && rowj < this->row ()))
      return false;
    // If the target row lies below column coli's entry, that entry has
    // already been passed and the next column is the first one ahead.
    my_idx = coli + (rowj > Base::mat.perm_elem (coli));
    return my_idx == coli && rowj == this->row ();
  }

  // Linear index of the current entry as the interpreter sees it.
  double
  interp_idx (void) const
  {
    return default_interp_idx(*this);
  }
  // Linear index of the current entry as an octave_idx_type.
  octave_idx_type
  flat_idx (void) const
  {
    return default_flat_idx(*this);
  }
  // Skip forward to the linear index i; see the two-argument overload.
  bool
  skip_ahead (octave_idx_type i)
  {
    return default_skip_ahead (*this, i);
  }
};

// Uses a one-sided binary search to move to the first stored entry of the
// current column, at or after the current position, whose row is at least
// rowj.  If the column has no such entry, my_idx lands on the column's end
// and adjust_col moves on to the next non-empty column.
//
// The probe distance doubles until the target is bracketed, then an
// ordinary binary search runs inside the bracket, so the cost is
// logarithmic in the number of entries skipped.
template<typename T>
void
sparse_iterator<T>::jump_to_row (octave_idx_type rowj)
{
  // One past the last stored entry of the current column.
  octave_idx_type ub = Base::mat.cidx (coli + 1);
  // Bracket: every entry at or before lo is known to be too small (or lies
  // before the starting position); hi is the current candidate.
  octave_idx_type lo = my_idx - 1;
  octave_idx_type hi = my_idx;
  octave_idx_type hidiff = 1;
  // Check hi < ub before reading ridx (hi): when the column is empty
  // (my_idx == ub) the entry at hi belongs to a later column, or lies past
  // the last stored entry if there is none.
  while (hi < ub && Base::mat.ridx (hi) < rowj)
    {
      lo = hi;
      hidiff *= 2;
      hi += hidiff;
      if (hi >= ub)
        {
          hi = ub;
          break;
        }
    }
  // Narrow (lo, hi] down to the first entry whose row is >= rowj.  Every
  // mid probed here is strictly below hi <= ub, so it stays in the column.
  while (hi - lo > 1)
    {
      octave_idx_type mid = (lo + hi) / 2;
      if (Base::mat.ridx (mid) < rowj)
        lo = mid;
      else
        hi = mid;
    }
  my_idx = hi;
  adjust_col (fdirc);
}

#endif