view liboctave/util/logical-index.h @ 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function. * logical-index.h (new file): Logical-indexing function; may be called on octave_value types via call_bool_index. * nz-iterators.h: Add base class nz_iterator for iterator types. array_iterator has a template bool selecting whether row/col are stored internally or computed on the fly. Add skip_ahead method, which skips forward to the next nonzero after its argument. Add flat_idx for computing the octave_idx_type index of the current position (with an assertion failure in case of overflow). Move is_zero to a separate file. * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h. * Array.h: Move forward declaration of array_iterator to a separate header file. * dim-vector.cc (dim_max): Refer to idx-bounds.h (max_idx). * array-iter-decl.h (new file): Header file for forward declaration of array_iterator. * direction.h: Add constants fdirc and bdirc to avoid having to reconstruct them. * dv-utils.h, dv-utils.cc (new files): Utility functions for querying and constructing dim-vectors. * idx-bounds.h (new file): Utility constants and functions for determining whether values will overflow the maximum allowed bounds. * interp-idx.h (new function to_flat_idx): Convert a row-col pair to a linear index of type octave_idx_type. * is-zero.h (new file): Function for determining whether an element is zero. * logical-index.tst: Add tests for correct return-value dimensions and large sparse matrix behavior.
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents
children
line wrap: on
line source

/*

Copyright (C) 2014 David Spies

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/
#if !defined (octave_logical_index_h)
#define octave_logical_index_h 1

#include "Array.h"
#include "dim-vector.h"
#include "dv-utils.h"
#include "direction.h"
#include "dispatch.h"
#include "nz-iterators.h"
#include "ov.h"

// Reshape the logical matrix idx (traversed in column-major order) into a
// sparse logical matrix with dims_rows rows.  If numel (idx) is not a
// multiple of dims_rows, the result is padded with trailing zeros.  The
// result has exactly dims_rows rows and ceil (numel (idx) / dims_rows)
// columns.
//
// IM must provide nnz (), rows (), cols () and a nonzero iterator type
// IM::iter_type.
//
// NOTE(review): dims_rows must be positive whenever idx has a nonzero
// element; a zero value divides by zero in the loop below.
template<typename IM>
Sparse<bool>
partial_sparse_reshape (const IM& idx, octave_idx_type dims_rows)
{
  typename IM::iter_type idx_iter (idx);

  dim_vector res_dims = dim_vector (idx.nnz (), 1);

  Array<octave_idx_type> rows (res_dims);
  Array<octave_idx_type> cols (res_dims);

  const octave_idx_type idx_rows = idx.rows ();
  const octave_idx_type idx_cols = idx.cols ();

  // (col_start_row, col_start_col) is the position in the reshaped matrix
  // at which column idx_col of idx begins.  It is advanced one column at a
  // time rather than derived from the linear index idx_col * idx_rows,
  // presumably so that very large sparse inputs cannot overflow
  // octave_idx_type.
  octave_idx_type i;
  octave_idx_type col_start_col = 0;
  octave_idx_type col_start_row = 0;
  octave_idx_type idx_col = 0;
  for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc);
      ++i, idx_iter.step (fdirc))
    {
      // Skip past every column of idx that precedes the current nonzero.
      for (; idx_col < idx_iter.col (); ++idx_col)
        {
          octave_idx_type next_row = col_start_row + idx_rows;
          col_start_col += next_row / dims_rows;
          col_start_row = next_row % dims_rows;
        }
      octave_idx_type next_row = col_start_row + idx_iter.row ();
      cols.xelem (i) = col_start_col + next_row / dims_rows;
      rows.xelem (i) = next_row % dims_rows;
    }

  // Without explicit dimensions, the Sparse constructor sizes its result from
  // the largest row and column index present.  The reshaped matrix could then
  // have fewer than dims_rows rows, or be 0x0 when idx has no nonzeros.  To
  // avoid that, walk the remaining columns to find where idx ends in the
  // reshaped layout.
  octave_idx_type res_cols = 0;
  if (dims_rows > 0)
    {
      for (; idx_col < idx_cols; ++idx_col)
        {
          octave_idx_type next_row = col_start_row + idx_rows;
          col_start_col += next_row / dims_rows;
          col_start_row = next_row % dims_rows;
        }
      // A partially filled last column still counts as a column.
      res_cols = col_start_col + (col_start_row > 0 ? 1 : 0);
    }

  Array<bool> trueScalar (dim_vector (1, 1), true);
  return Sparse<bool> (trueScalar, idx_vector (rows), idx_vector (cols),
                       dims_rows, res_cols);
}

// Given matrix mat and logical matrix idx, takes mat(idx) and returns the
// result as an Array.  "full" indicates whether mat is full (i.e. can use a
// linear index) or not.  IterM is the nonzero-iterator type for M (the type
// of mat).  If !full, mat and idx must have the same height.
//
// idx_dims are the dimensions that decide the shape of the result.  They are
// normally idx.dims (), but they differ when idx has been reshaped to match
// mat's height.  The result shape must follow the index the caller supplied,
// not the reshaped one.
template<bool full, typename IterM, typename M, typename IM>
Array<typename M::element_type>
take_bool_with_index (const M& mat, const IM& idx, const dim_vector& idx_dims)
{
  typedef typename M::element_type ELT_T;

  // After testing over 150 edge-cases in Matlab, this seems to be the rule for
  // logical indexing return-value dimensions
  const dim_vector mat_dims = mat.dims ();
  octave_idx_type idx_nnz = idx.nnz ();
  dim_vector res_dims;
  if (idx_dims(0) == 0 && idx_dims(1) == 0)
    res_dims = dim_vector (0, 0);
  else if (dv_is_scalar (idx_dims) && idx_nnz == 0)
    res_dims = dim_vector (0, 0);
  else if (dv_is_extended_vector (mat_dims))
    res_dims = dv_match_vector (mat_dims, idx_nnz);
  else
    {
      if (dv_is_row (idx_dims))
        res_dims = dim_vector (1, idx_nnz);
      else
        res_dims = dim_vector (idx_nnz, 1);
    }

  IterM mat_iter (mat);
  typename IM::iter_type idx_iter (idx);

  Array<ELT_T> res (res_dims);

  // Walk the nonzeros of idx in order.  Advance the matrix iterator to the
  // same position, by linear index when mat is full and by (row, col)
  // otherwise.  A position where mat has no stored nonzero contributes a
  // zero.
  octave_idx_type i;
  for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc);
      ++i, idx_iter.step (fdirc))
    {
      bool nz;
      if (full)
        nz = mat_iter.skip_ahead (idx_iter.flat_idx ());
      else
        nz = mat_iter.skip_ahead (idx_iter.row (), idx_iter.col ());
      if (nz)
        res.xelem (i) = mat_iter.data ();
      else
        res.xelem (i) = static_cast<ELT_T> (0);
    }
  return res;
}

// Same as above, with the result shape determined by idx's own dimensions.
template<bool full, typename IterM, typename M, typename IM>
Array<typename M::element_type>
take_bool_with_index (const M& mat, const IM& idx)
{
  return take_bool_with_index<full, IterM> (mat, idx, idx.dims ());
}

// Determine whether we need to reshape idx before calling take_bool_with_index
// to ensure mat and idx have the same height
template<bool full, typename IterM, typename M, typename IM>
Array<typename M::element_type>
take_bool_index (const M& mat, const IM& idx)
{
  if (full || idx.rows () == mat.rows ())
    return take_bool_with_index<full, IterM> (mat, idx);
  else
    {
      // Reshaping changes idx's dimensions, but the result must still be
      // shaped after the original index.  For example, a logical row vector
      // must yield a row, as it does on the full-array path.  So forward the
      // original dimensions explicitly.
      return take_bool_with_index<full, IterM> (
          mat, partial_sparse_reshape (idx, mat.rows ()), idx.dims ());
    }
}

template<typename IM>
Array<bool>
bool_index (const PermMatrix& mat, const IM& idx)
{
  // Permutation matrices are not linearly indexed here (full = false), so
  // positions are matched by (row, col) through a perm_iterator.
  typedef perm_iterator nz_iter;
  const Array<bool> picked = take_bool_index<false, nz_iter> (mat, idx);
  return picked;
}

template<typename ELT_T, typename IM>
Array<ELT_T>
bool_index (const DiagArray2<ELT_T>& mat, const IM& idx)
{
  // Diagonal matrices are traversed by (row, col) (full = false) using the
  // diagonal nonzero iterator.
  typedef diag_iterator<ELT_T> nz_iter;
  const Array<ELT_T> picked = take_bool_index<false, nz_iter> (mat, idx);
  return picked;
}

template<typename ELT_T, typename IM>
Array<ELT_T>
bool_index (const Array<ELT_T>& mat, const IM& idx)
{
  // Dense arrays support linear indexing (full = true), so the array
  // iterator is advanced by flat index and idx is never reshaped.
  typedef array_iterator<ELT_T, false> nz_iter;
  const Array<ELT_T> picked = take_bool_index<true, nz_iter> (mat, idx);
  return picked;
}

template<typename ELT_T, typename IM>
Sparse<ELT_T>
bool_index (const Sparse<ELT_T>& mat, const IM& idx)
{
  // Gather the selected elements into a dense Array by (row, col)
  // traversal (full = false), then convert that Array back to sparse
  // storage for the caller.
  typedef sparse_iterator<ELT_T> nz_iter;
  const Array<ELT_T> dense = take_bool_index<false, nz_iter> (mat, idx);
  return Sparse<ELT_T> (dense);
}


template<typename M>
struct mwrapper
{
  // Functor handed to dispatch ().  Once the concrete type IM of the index
  // value is known, it forwards to the bool_index overload for M and wraps
  // the outcome in a one-element octave_value_list.
  template<typename IM>
  struct idx_caller
  {
    octave_value_list
    operator() (const IM& mask, const M& source)
    {
      octave_value_list retval (1);
      retval(0) = bool_index (source, mask);
      return retval;
    }
  };

  // Resolve the runtime type of idx via dispatch () and return the single
  // value produced by idx_caller.  idx_caller lives inside this class
  // template so that it can use M while itself taking only one template
  // parameter, which is what dispatch expects.
  static octave_value
  do_call (const M& mat, const octave_value& idx)
  {
    octave_value_list found = dispatch<idx_caller> (idx, mat, "bool_index");
    return found(0);
  }
};

// Compute mat(idx) for a logical index held in the octave_value idx,
// whatever concrete type it stores, and return the result as an
// octave_value.
template<typename M>
octave_value
call_bool_index (const M& mat, const octave_value& idx)
{
  typedef mwrapper<M> wrapper_type;
  return wrapper_type::do_call (mat, idx);
}

#endif