Mercurial > octave-dspies
view liboctave/util/logical-index.h @ 19010:3fb030666878 draft default tip dspies
Added special-case logical-indexing function
* logical-index.h (New file) : Logical-indexing function. May be called on
octave_value types via call_bool_index
* nz-iterators.h : Add base class nz_iterator for iterator types. The array
iterator has a template bool selecting whether to store the row-col position
internally or compute it on the fly.
Add skip_ahead method, which skips forward to the next nonzero after its
argument.
Add flat_index for computing the octave_idx_type index of the current position
(with an assertion failure in the case of overflow).
Move is_zero to a separate file.
* ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc
(do_index_op): Add call to call_bool_index in logical-index.h
* Array.h : Move the forward declaration for array_iterator to a separate header file
* dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx)
* array-iter-decl.h (New file): Header file for forward declaration of
array-iterator
* direction.h : Add constants fdirc and bdirc to avoid having to reconstruct
them
* dv-utils.h, dv-utils.cc (New files) :
Utility functions for querying and constructing dim-vectors
* idx-bounds.h (New file) :
Utility constants and functions for determining whether things will overflow
the maximum allowed bounds
* interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear
index of octave_idx_type
* is-zero.h (New file) : Function for determining whether an element is zero
* logical-index.tst : Add tests for correct return-value dimensions and large
sparse matrix behavior
author | David Spies <dnspies@gmail.com> |
---|---|
date | Fri, 25 Jul 2014 13:39:31 -0600 |
parents | |
children |
line wrap: on
line source
/* Copyright (C) 2014 David Spies This file is part of Octave. Octave is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version. Octave is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Octave; see the file COPYING. If not, see <http://www.gnu.org/licenses/>. */ #if !defined (octave_logical_index_h) #define octave_logical_index_h 1 #include "Array.h" #include "dim-vector.h" #include "dv-utils.h" #include "direction.h" #include "dispatch.h" #include "nz-iterators.h" #include "ov.h" //Reshapes idx to have dims_rows rows (adding extra zeros if necessary) template<typename IM> Sparse<bool> partial_sparse_reshape (const IM& idx, octave_idx_type dims_rows) { typename IM::iter_type idx_iter (idx); dim_vector res_dims = dim_vector (idx.nnz (), 1); Array<octave_idx_type> rows (res_dims); Array<octave_idx_type> cols (res_dims); const octave_idx_type idx_rows = idx.rows (); octave_idx_type i; octave_idx_type col_start_col = 0; octave_idx_type col_start_row = 0; octave_idx_type idx_col = 0; for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc); ++i, idx_iter.step (fdirc)) { for (; idx_col < idx_iter.col (); ++idx_col) { octave_idx_type next_row = col_start_row + idx_rows; col_start_col += next_row / dims_rows; col_start_row = next_row % dims_rows; } octave_idx_type next_row = col_start_row + idx_iter.row (); cols.xelem (i) = col_start_col + next_row / dims_rows; rows.xelem (i) = next_row % dims_rows; } Array<bool> trueScalar (dim_vector (1, 1), true); return Sparse<bool> (trueScalar, idx_vector (rows), idx_vector (cols)); } // Given matrix mat and logical 
matrix idx, takes mat(idx) and returns the // result as an Array. "full" indicates if mat is full (ie can use linear index) // or not. IterM is the nonzero-iterator type for M (the type of mat) // if !full, mat and idx must have the same height template<bool full, typename IterM, typename M, typename IM> Array<typename M::element_type> take_bool_with_index (const M& mat, const IM& idx) { typedef typename M::element_type ELT_T; // After testing over 150 edge-cases in Matlab, this seems to be the rule for // logical indexing return-value dimensions const dim_vector idx_dims = idx.dims (); const dim_vector mat_dims = mat.dims (); octave_idx_type idx_nnz = idx.nnz (); dim_vector res_dims; if (idx_dims(0) == 0 && idx_dims(1) == 0) res_dims = dim_vector (0, 0); else if (dv_is_scalar (idx_dims) && idx_nnz == 0) res_dims = dim_vector (0, 0); else if (dv_is_extended_vector (mat_dims)) res_dims = dv_match_vector (mat_dims, idx_nnz); else { if (dv_is_row (idx_dims)) res_dims = dim_vector (1, idx_nnz); else res_dims = dim_vector (idx_nnz, 1); } IterM mat_iter (mat); typename IM::iter_type idx_iter (idx); Array<ELT_T> res (res_dims); octave_idx_type i; for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc); ++i, idx_iter.step (fdirc)) { bool nz; if (full) nz = mat_iter.skip_ahead (idx_iter.flat_idx ()); else nz = mat_iter.skip_ahead (idx_iter.row (), idx_iter.col ()); if (nz) res.xelem (i) = mat_iter.data (); else res.xelem (i) = static_cast<ELT_T> (0); } return res; } // Determine whether we need to reshape idx before calling take_bool_with_index // to ensure mat and idx have the same height template<bool full, typename IterM, typename M, typename IM> Array<typename M::element_type> take_bool_index (const M& mat, const IM& idx) { if (full || idx.rows () == mat.rows ()) return take_bool_with_index<full, IterM> (mat, idx); else return take_bool_with_index<full, IterM> ( mat, partial_sparse_reshape (idx, mat.rows ())); } template<typename IM> Array<bool> bool_index (const 
PermMatrix& mat, const IM& idx) { return take_bool_index<false, perm_iterator> (mat, idx); } template<typename ELT_T, typename IM> Array<ELT_T> bool_index (const DiagArray2<ELT_T>& mat, const IM& idx) { return take_bool_index<false, diag_iterator<ELT_T> > (mat, idx); } template<typename ELT_T, typename IM> Array<ELT_T> bool_index (const Array<ELT_T>& mat, const IM& idx) { return take_bool_index<true, array_iterator<ELT_T, false> > (mat, idx); } template<typename ELT_T, typename IM> Sparse<ELT_T> bool_index (const Sparse<ELT_T>& mat, const IM& idx) { const Array<ELT_T> res = take_bool_index<false, sparse_iterator<ELT_T> > ( mat, idx); return Sparse<ELT_T> (res); } template<typename M> struct mwrapper { template<typename IM> struct idx_caller { octave_value_list operator() (const IM& arg, const M& into) { octave_value_list res (1); res(0) = bool_index (into, arg); return res; } }; // Dispatch call to bool_index for type of idx. The nested functor allows // for multiple template parameters (since dispatch assumes its template // argument has exactly one template parameter) static octave_value do_call (const M& mat, const octave_value& idx) { octave_value_list res = dispatch<idx_caller> (idx, mat, "bool_index"); return res(0); } }; template<typename M> octave_value call_bool_index (const M& mat, const octave_value& idx) { return mwrapper<M>::do_call (mat, idx); } #endif