# HG changeset patch # User David Spies # Date 1406317171 21600 # Node ID 3fb0306668782ef7fcfb72afdc2fe76937cd549b # Parent 8d47ce2053f257c446e8b32c5d5b33e74e896632 Added special-case logical-indexing function * logical-index.h (New file) : Logical-indexing function. May be called on octave_value types via call_bool_index * nz-iterators.h : Add base-class nz_iterator for iterator types. Array has template bool for whether to internally store row-col or compute on the fly Add skip_ahead method which skips forward to the next nonzero after its argument Add flat_index for computing octave_idx_type index of current position (with assertion failure in the case of overflow) Move is_zero to separate file * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h * Array.h : Move forward-declaration for array_iterator to separate header file * dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx) * array-iter-decl.h (New file): Header file for forward declaration of array-iterator * direction.h : Add constants fdirc and bdirc to avoid having to reconstruct them * dv-utils.h, dv-utils.cc (New files) : Utility functions for querying and constructing dim-vectors * idx-bounds.h (New file) : Utility constants and functions for determining whether things will overflow the maximum allowed bounds * interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear index of octave_idx_type * is-zero.h (New file) : Function for determining whether an element is zero * logical-index.tst : Add tests for correct return-value dimensions and large sparse matrix behavior diff -r 8d47ce2053f2 -r 3fb030666878 libinterp/octave-value/ov-base-diag.cc --- a/libinterp/octave-value/ov-base-diag.cc Mon Jul 14 13:07:59 2014 -0600 +++ b/libinterp/octave-value/ov-base-diag.cc Fri Jul 25 13:39:31 2014 -0600 @@ -37,6 +37,7 @@ #include "gripes.h" #include "oct-stream.h" #include "ops.h" +#include "logical-index.h" #include "ls-oct-ascii.h" @@ -125,6 +126,8 @@ retval = to_dense ().do_index_op (idx, resize_ok); } } + else if (idx.length () == 1 && idx(0).is_bool_type ()) + retval = call_bool_index (matrix, idx(0)); else retval = to_dense ().do_index_op (idx, resize_ok); diff -r 8d47ce2053f2 -r 3fb030666878 libinterp/octave-value/ov-base-mat.cc --- a/libinterp/octave-value/ov-base-mat.cc Mon Jul 14 13:07:59 2014 -0600 +++ b/libinterp/octave-value/ov-base-mat.cc Fri Jul 25 13:39:31 2014 -0600 @@ -34,6 +34,7 @@ #include "ov-base-mat.h" #include "ov-base-scalar.h" #include "pr-output.h" +#include "logical-index.h" template octave_value @@ -146,15 +147,20 @@ case 1: { - idx_vector i = idx (0).index_vector (); - - if (! error_state) + if (idx(0).is_bool_type ()) + retval = call_bool_index (matrix, idx(0)); + else { - // optimize single scalar index. - if (! resize_ok && i.is_scalar ()) - retval = cmatrix.checkelem (i(0)); - else - retval = MT (matrix.index (i, resize_ok)); + idx_vector i = idx (0).index_vector (); + + if (! error_state) + { + // optimize single scalar index. + if (! resize_ok && i.is_scalar ()) + retval = cmatrix.checkelem (i(0)); + else + retval = MT (matrix.index (i, resize_ok)); + } } } break; diff -r 8d47ce2053f2 -r 3fb030666878 libinterp/octave-value/ov-base-sparse.cc --- a/libinterp/octave-value/ov-base-sparse.cc Mon Jul 14 13:07:59 2014 -0600 +++ b/libinterp/octave-value/ov-base-sparse.cc Fri Jul 25 13:39:31 2014 -0600 @@ -29,6 +29,7 @@ #include #include +#include "logical-index.h" #include "oct-obj.h" #include "ov-base.h" #include "quit.h" @@ -61,10 +62,16 @@ case 1: { - idx_vector i = idx (0).index_vector (); + const octave_value& v = idx(0); + if (v.is_bool_type ()) + retval = call_bool_index (matrix, v); + else + { + idx_vector i = v.index_vector (); - if (! error_state) - retval = octave_value (matrix.index (i, resize_ok)); + if (!error_state) + retval = octave_value (matrix.index (i, resize_ok)); + } } break; diff -r 8d47ce2053f2 -r 3fb030666878 libinterp/octave-value/ov-perm.cc --- a/libinterp/octave-value/ov-perm.cc Mon Jul 14 13:07:59 2014 -0600 +++ b/libinterp/octave-value/ov-perm.cc Fri Jul 25 13:39:31 2014 -0600 @@ -35,6 +35,7 @@ #include "gripes.h" #include "ops.h" #include "pr-output.h" +#include "logical-index.h" #include "ls-oct-ascii.h" @@ -117,6 +118,8 @@ { retval = matrix.checkelem (idx0(0), idx1(0)); } + else if (nidx == 1 && idx(0).is_bool_type ()) + retval = call_bool_index (matrix, idx(0)); else retval = to_dense ().do_index_op (idx, resize_ok); } diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/array/Array.h --- a/liboctave/array/Array.h Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/array/Array.h Fri Jul 25 13:39:31 2014 -0600 @@ -41,11 +41,7 @@ #include "quit.h" #include "oct-mem.h" #include "oct-refcount.h" - -//Forward declaration for array_iterator, -//the nonzero-iterator type for Array (in nz_iterator.h) -template -class array_iterator; +#include "array-iter-decl.h" // One dimensional array class. Handles the reference counting for // all the derived classes. diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/array/dim-vector.cc --- a/liboctave/array/dim-vector.cc Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/array/dim-vector.cc Fri Jul 25 13:39:31 2014 -0600 @@ -27,6 +27,7 @@ #include +#include "idx-bounds.h" #include "dim-vector.h" // The maximum allowed value for a dimension extent. This will normally be a @@ -36,7 +37,7 @@ octave_idx_type dim_vector::dim_max (void) { - return std::numeric_limits::max () - 1; + return max_idx; } void diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/array-iter-decl.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/array-iter-decl.h Fri Jul 25 13:39:31 2014 -0600 @@ -0,0 +1,33 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_array_iter_decl_h) +#define octave_array_iter_decl_h 1 + +//Forward declaration for array_iterator, +//the nonzero-iterator type for Array (in nz_iterator.h) +//Must be declared in its own file because one of the template parameters +//has a default value. +template +class array_iterator; + +#endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/direction.h --- a/liboctave/util/direction.h Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/util/direction.h Fri Jul 25 13:39:31 2014 -0600 @@ -73,4 +73,7 @@ } }; +extern const dir_handler fdirc; +extern const dir_handler bdirc; + #endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/dv-utils.cc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/dv-utils.cc Fri Jul 25 13:39:31 2014 -0600 @@ -0,0 +1,85 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "dim-vector.h" +#include + +bool +dv_is_extended_vector (const dim_vector& dv) +{ + bool found = false; + for (int i = 0; i < dv.length (); i++) + { + if (dv(i) != 1) + { + if (found) + return false; + else + found = true; + } + } + return found; +} + +bool +dv_is_scalar (const dim_vector& dv) +{ + for (int i = 0; i < dv.length (); i++) + { + if (dv(i) != 1) + return false; + } + return true; +} + +bool +dv_is_row (const dim_vector& dv) +{ + return dv_is_extended_vector (dv) && dv(1) != 1; +} + +//Returns the dimension along which a vector is a nonsingleton +int +dv_vector_dimension (const dim_vector& dv) +{ + assert(dv_is_extended_vector (dv)); + for (int i = 0; i < dv.length (); i++) + { + if (dv(i) != 1) + return i; + } + liboctave_fatal ("No non-1 dimension"); +} + +dim_vector +dv_match_vector (const dim_vector& dv, octave_idx_type numel) +{ + if(numel == 1) + return dim_vector (1, 1); + dim_vector res = dv; + res (dv_vector_dimension (dv)) = numel; + return res; +} diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/dv-utils.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/dv-utils.h Fri Jul 25 13:39:31 2014 -0600 @@ -0,0 +1,40 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_dv_utils_h) +#define octave_dv_utils_h 1 + +class dim_vector; + +//Returns true if the argument has exactly one nonsingleton dimension +bool dv_is_extended_vector(const dim_vector& dv); + +//Returns true if all of dv's dimension-lengths are 1 +bool dv_is_scalar(const dim_vector& dv); + +//Returns true if the argument's second dimension is the only nonsingleton +bool dv_is_row(const dim_vector& dv); + +//Creates a vector of length numel whose nonsingleton dimension is the same as dv +dim_vector dv_match_vector(const dim_vector& dv, octave_idx_type numel); + +#endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/idx-bounds.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/idx-bounds.h Fri Jul 25 13:39:31 2014 -0600 @@ -0,0 +1,53 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_idx_bounds_h) +#define octave_idx_bounds_h 1 + +#include + +//The maximum value for octave_idx_type (and the maximum possible size of an +//Array) +const octave_idx_type idx_type_max_value = + std::numeric_limits::max (); + +//The maximum index an array can have (idx_type_max_value - 1) +const octave_idx_type max_idx = idx_type_max_value - 1; + +//Determines if a row-col pair will overflow octave_idx_type if translated into +//a linear index +inline bool +row_col_to_idx_overflows (octave_idx_type row, octave_idx_type col, + octave_idx_type height) +{ + return row > (max_idx - col) / height; +} + +//Determines if a height-width pair will overflow octave_idx_type when nelem is +//extracted +inline bool +height_width_to_nelem_overflows (octave_idx_type height, octave_idx_type width) +{ + return width > idx_type_max_value / height; +} + +#endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/interp-idx.h --- a/liboctave/util/interp-idx.h Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/util/interp-idx.h Fri Jul 25 13:39:31 2014 -0600 @@ -22,6 +22,9 @@ #if !defined (octave_interp_idx_h) #define octave_interp_idx_h 1 +#include "idx-bounds.h" +#include "Array-util.h" + // Simple method for converting between C++ octave_idx_type and // Octave data-types (convert to double and add one). inline double @@ -44,4 +47,13 @@ return col * static_cast (dims(0)) + row + 1; } +//Converts a row-column pair to a "flat" linear index. +//Asserts that its result won't overflow octave_idx_type +inline octave_idx_type +to_flat_idx (octave_idx_type row, octave_idx_type col, const dim_vector& dims) +{ + assert(!row_col_to_idx_overflows (row, col, dims(0))); + return compute_index (row, col, dims, false); +} + #endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/is-zero.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/is-zero.h Fri Jul 25 13:39:31 2014 -0600 @@ -0,0 +1,34 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ +#if !defined (octave_is_zero_h) +#define octave_is_zero_h + +// A generic method for checking if some element of a matrix with +// element type T is zero. +template +inline bool +is_zero (T t) +{ + return t == static_cast (0); +} + +#endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/logical-index.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/liboctave/util/logical-index.h Fri Jul 25 13:39:31 2014 -0600 @@ -0,0 +1,197 @@ +/* + +Copyright (C) 2014 David Spies + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ +#if !defined (octave_logical_index_h) +#define octave_logical_index_h 1 + +#include "Array.h" +#include "dim-vector.h" +#include "dv-utils.h" +#include "direction.h" +#include "dispatch.h" +#include "nz-iterators.h" +#include "ov.h" + +//Reshapes idx to have dims_rows rows (adding extra zeros if necessary) +template +Sparse +partial_sparse_reshape (const IM& idx, octave_idx_type dims_rows) +{ + typename IM::iter_type idx_iter (idx); + + dim_vector res_dims = dim_vector (idx.nnz (), 1); + + Array rows (res_dims); + Array cols (res_dims); + + const octave_idx_type idx_rows = idx.rows (); + + octave_idx_type i; + octave_idx_type col_start_col = 0; + octave_idx_type col_start_row = 0; + octave_idx_type idx_col = 0; + for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc); + ++i, idx_iter.step (fdirc)) + { + for (; idx_col < idx_iter.col (); ++idx_col) + { + octave_idx_type next_row = col_start_row + idx_rows; + col_start_col += next_row / dims_rows; + col_start_row = next_row % dims_rows; + } + octave_idx_type next_row = col_start_row + idx_iter.row (); + cols.xelem (i) = col_start_col + next_row / dims_rows; + rows.xelem (i) = next_row % dims_rows; + } + Array trueScalar (dim_vector (1, 1), true); + return Sparse (trueScalar, idx_vector (rows), idx_vector (cols)); +} + +// Given matrix mat and logical matrix idx, takes mat(idx) and returns the +// result as an Array. "full" indicates if mat is full (ie can use linear index) +// or not. IterM is the nonzero-iterator type for M (the type of mat) +// if !full, mat and idx must have the same height +template +Array +take_bool_with_index (const M& mat, const IM& idx) +{ + typedef typename M::element_type ELT_T; + + // After testing over 150 edge-cases in Matlab, this seems to be the rule for + // logical indexing return-value dimensions + const dim_vector idx_dims = idx.dims (); + const dim_vector mat_dims = mat.dims (); + octave_idx_type idx_nnz = idx.nnz (); + dim_vector res_dims; + if (idx_dims(0) == 0 && idx_dims(1) == 0) + res_dims = dim_vector (0, 0); + else if (dv_is_scalar (idx_dims) && idx_nnz == 0) + res_dims = dim_vector (0, 0); + else if (dv_is_extended_vector (mat_dims)) + res_dims = dv_match_vector (mat_dims, idx_nnz); + else + { + if (dv_is_row (idx_dims)) + res_dims = dim_vector (1, idx_nnz); + else + res_dims = dim_vector (idx_nnz, 1); + } + + IterM mat_iter (mat); + typename IM::iter_type idx_iter (idx); + + Array res (res_dims); + + octave_idx_type i; + for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc); + ++i, idx_iter.step (fdirc)) + { + bool nz; + if (full) + nz = mat_iter.skip_ahead (idx_iter.flat_idx ()); + else + nz = mat_iter.skip_ahead (idx_iter.row (), idx_iter.col ()); + if (nz) + res.xelem (i) = mat_iter.data (); + else + res.xelem (i) = static_cast (0); + } + return res; +} + +// Determine whether we need to reshape idx before calling take_bool_with_index +// to ensure mat and idx have the same height +template +Array +take_bool_index (const M& mat, const IM& idx) +{ + if (full || idx.rows () == mat.rows ()) + return take_bool_with_index (mat, idx); + else + return take_bool_with_index ( + mat, partial_sparse_reshape (idx, mat.rows ())); +} + +template +Array +bool_index (const PermMatrix& mat, const IM& idx) +{ + return take_bool_index (mat, idx); +} + +template +Array +bool_index (const DiagArray2& mat, const IM& idx) +{ + return take_bool_index > (mat, idx); +} + +template +Array +bool_index (const Array& mat, const IM& idx) +{ + return take_bool_index > (mat, idx); +} + +template +Sparse +bool_index (const Sparse& mat, const IM& idx) +{ + const Array res = take_bool_index > ( + mat, idx); + return Sparse (res); +} + + +template +struct mwrapper +{ + template + struct idx_caller + { + octave_value_list + operator() (const IM& arg, const M& into) + { + octave_value_list res (1); + res(0) = bool_index (into, arg); + return res; + } + }; + + // Dispatch call to bool_index for type of idx. The nested functor allows + // for multiple template parameters (since dispatch assumes its template + // argument has exactly one template parameter) + static octave_value + do_call (const M& mat, const octave_value& idx) + { + octave_value_list res = dispatch (idx, mat, "bool_index"); + return res(0); + } +}; + +template +octave_value +call_bool_index (const M& mat, const octave_value& idx) +{ + return mwrapper::do_call (mat, idx); +} + +#endif diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/module.mk --- a/liboctave/util/module.mk Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/util/module.mk Fri Jul 25 13:39:31 2014 -0600 @@ -3,6 +3,7 @@ UTIL_INC = \ util/action-container.h \ + util/array-iter-decl.h \ util/base-list.h \ util/byte-swap.h \ util/caseless-str.h \ @@ -10,12 +11,16 @@ util/cmd-hist.h \ util/data-conv.h \ util/direction.h \ + util/dv-utils.h \ util/find.h \ util/functor.h \ util/glob-match.h \ + util/idx-bounds.h \ util/interp-idx.h \ + util/is-zero.h \ util/lo-array-gripes.h \ util/lo-cutils.h \ + util/logical-index.h \ util/lo-ieee.h \ util/lo-macros.h \ util/lo-math.h \ @@ -60,6 +65,7 @@ util/cmd-edit.cc \ util/cmd-hist.cc \ util/data-conv.cc \ + util/dv-utils.cc \ util/glob-match.cc \ util/lo-array-gripes.cc \ util/lo-ieee.cc \ diff -r 8d47ce2053f2 -r 3fb030666878 liboctave/util/nz-iterators.h --- a/liboctave/util/nz-iterators.h Mon Jul 14 13:07:59 2014 -0600 +++ b/liboctave/util/nz-iterators.h Fri Jul 25 13:39:31 2014 -0600 @@ -22,6 +22,9 @@ #if !defined (octave_nz_iterators_h) #define octave_nz_iterators_h 1 +#include + +#include "array-iter-decl.h" #include "interp-idx.h" #include "oct-inttypes.h" #include "Array.h" @@ -29,6 +32,7 @@ #include "PermMatrix.h" #include "Sparse.h" #include "direction.h" +#include "is-zero.h" // This file contains generic column-major iterators over // the nonzero elements of any array or matrix. If you have a matrix mat @@ -61,67 +65,104 @@ // T elem = iter.data(); // // ... Do something with these // } -// -// Note that array_iter for indexing over full matrices also includes -// a iter.flat_index () method which returns an octave_idx_type. -// -// The other iterators to not have a flat_index() method because they -// risk overflowing octave_idx_type. It is recommended you take care -// to implement your function in a way that accounts for this problem. -// -// FIXME: I'd like to add in these -// default no-parameter versions of -// begin() and step() to each of the -// classes. But the C++ compiler complains -// because apparently I'm not allowed to overload -// templated methods with non-templated ones. Any -// ideas for work-arounds? -// -//#define INCLUDE_DEFAULT_STEPS \ -// void begin (void) \ -// { \ -// dir_handler dirc; \ -// begin (dirc); \ -// } \ -// void step (void) \ -// { \ -// dir_handler dirc; \ -// step (dirc); \ -// } \ -// bool finished (void) const \ -// { \ -// dir_handler dirc; \ -// return finished (dirc); \ -// } + +struct rowcol +{ + rowcol (void) { } + + rowcol (octave_idx_type i, const dim_vector& dims) + { +#if defined(BOUNDS_CHECKING) + check_index(i, dims); +#endif + row = i % dims(0); + col = i / dims(0); + } + + rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims) + { +#if defined(BOUNDS_CHECKING) + check_index(rowj, coli, dims); +#endif + row = rowj; + col = coli; + } + + octave_idx_type row; + octave_idx_type col; +}; + +// Default (non-linear) implementation of nz_iterator.interp_idx +template +inline double +default_interp_idx(const Iter& it) +{ + return to_interp_idx (it.row (), it.col (), it.dims); +} -// A generic method for checking if some element of a matrix with -// element type T is zero. -template -bool -is_zero (T t) +// Default (non-linear) implementation of nz_iterator.flat_idx +template +inline octave_idx_type +default_flat_idx (const Iter& it) +{ + return to_flat_idx (it.row (), it.col (), it.dims); +} + +// Default (non-linear) implementation of +// nz_iterator.skip_ahead(octave_idx_type) +template +inline bool +default_skip_ahead (Iter& it, octave_idx_type i) +{ + return it.skip_ahead (i % it.dims(0), i / it.dims(0)); +} + + +template +class nz_iterator { - return t == static_cast (0); -} +protected: + const M mat; + +public: + const dim_vector dims; + nz_iterator (const M& arg_mat) : mat (arg_mat), dims (arg_mat.dims ()) { } + + // Because of the abundance of templates, this interface cannot be specified + // using pure virtual methods. But here's what each method does: + // + // begin(dirc): moves the iterator to the beginning or the end + // finished(dirc): checks whether the iterator has advanced to the end or the + // beginning + // row, col, and data: return the current row, column, and element + // interp_idx and flat_idx: return the current linear index (interp_idx as + // a double the way the interpreter sees it and flat_idx as an + // octave_idx_type) + // step(dirc): steps forward to the next nonzero element + // skip_ahead(args): skips forward to the nearest nonzero element after + // (including) args. Returns true iff the position at args is nonzero +}; // An iterator over full arrays. When the number of dimensions exceeds // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1) // // This mimics the behavior of the "find" method (both in Octave and Matlab) // on many-dimensional matrices. +// +// userc indicates whether to (true) track row and column values or (false) +// compute them on the fly as needed -template -class array_iterator +template +class array_iterator : public nz_iterator > { private: - const Array& mat; + typedef nz_iterator > Base; - //Actual total number of columns = mat.dims().numel(1) - //can be different from length of row dimension const octave_idx_type totcols; const octave_idx_type numels; - octave_idx_type coli; - octave_idx_type rowj; + rowcol myrc; + octave_idx_type my_idx; template @@ -129,11 +170,14 @@ step_once (dir_handler dirc) { my_idx += dir; - rowj += dir; - if (dirc.is_ended (rowj, mat.rows ())) + if (userc) { - rowj = dirc.begin (mat.rows ()); - coli += dir; + myrc.row += dir; + if (dirc.is_ended (myrc.row, Base::mat.rows ())) + { + myrc.row = dirc.begin (Base::mat.rows ()); + myrc.col += dir; + } } } @@ -147,35 +191,66 @@ } } + template + bool + next_nz (dir_handler dirc) + { + if (is_zero (data ())) + { + step (dirc); + return false; + } + else + return true; + } + + bool + skip_ahead (octave_idx_type i, const rowcol& rc) + { + if (i < my_idx) + return false; + my_idx = i; + if (userc) + myrc = rc; + return next_nz (fdirc); + } public: array_iterator (const Array& arg_mat) - : mat (arg_mat), totcols (arg_mat.dims ().numel (1)), numels ( + : nz_iterator > (arg_mat), totcols (Base::dims.numel (1)), numels ( totcols * arg_mat.rows ()) { - dir_handler dirc; - begin (dirc); + begin (fdirc); } template void begin (dir_handler dirc) { - coli = dirc.begin (totcols); - rowj = dirc.begin (mat.rows ()); - my_idx = dirc.begin (mat.numel ()); + if(userc) + { + myrc.col = dirc.begin (totcols); + myrc.row = dirc.begin (Base::mat.rows ()); + } + my_idx = dirc.begin (Base::mat.numel ()); move_to_nz (dirc); } octave_idx_type col (void) const { - return coli; + if (userc) + return myrc.col; + else + return my_idx / Base::mat.rows (); } octave_idx_type row (void) const { - return rowj; + if (userc) + return myrc.row; + else + return my_idx % Base::mat.rows (); } double interp_idx (void) const @@ -190,7 +265,7 @@ T data (void) const { - return mat.elem (my_idx); + return Base::mat.elem (my_idx); } template @@ -206,13 +281,41 @@ { return dirc.is_ended (my_idx, numels); } + + bool + skip_ahead (octave_idx_type i) + { + return skip_ahead (i, rowcol (i, Base::dims)); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type coli) + { + return skip_ahead (coli * Base::mat.rows () + rowj, + rowcol (rowj, coli, Base::dims)); + } + + bool + skip_ahead (const Array& idxs) + { + //TODO Check bounds + octave_idx_type rowj = idxs(0); + octave_idx_type coli = 0; + for(int i = idxs.numel () - 1; i > 1; --i) { + coli += idxs(i); + coli *= Base::dims(i); + } + coli += idxs(1); + return skip_ahead (rowj, coli); + } }; template -class sparse_iterator +class sparse_iterator : public nz_iterator > { private: - const Sparse& mat; + typedef nz_iterator > Base; + octave_idx_type coli; octave_idx_type my_idx; @@ -221,32 +324,28 @@ adjust_col (dir_handler dirc) { while (!finished (dirc) - && dirc.is_ended (my_idx, mat.cidx (coli), mat.cidx (coli + 1))) + && dirc.is_ended (my_idx, Base::mat.cidx (coli), Base::mat.cidx (coli + 1))) coli += dir; } + void jump_to_row (octave_idx_type rowj); + public: sparse_iterator (const Sparse& arg_mat) : - mat (arg_mat) + nz_iterator > (arg_mat) { - dir_handler dirc; - begin (dirc); + begin (fdirc); } template void begin (dir_handler dirc) { - coli = dirc.begin (mat.cols ()); - my_idx = dirc.begin (mat.nnz ()); + coli = dirc.begin (Base::mat.cols ()); + my_idx = dirc.begin (Base::mat.nnz ()); adjust_col (dirc); } - double - interp_idx (void) const - { - return to_interp_idx (row (), col (), mat.dims ()); - } octave_idx_type col (void) const { @@ -255,12 +354,13 @@ octave_idx_type row (void) const { - return mat.ridx (my_idx); + return Base::mat.ridx (my_idx); } + T data (void) const { - return mat.data (my_idx); + return Base::mat.data (my_idx); } template void @@ -273,15 +373,49 @@ bool finished (dir_handler dirc) const { - return dirc.is_ended (coli, mat.cols ()); + return dirc.is_ended (coli, Base::mat.cols ()); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli) + { + //TODO Check bounds + if (arg_coli < coli || (arg_coli == coli && rowj < this->row ())) + { + return false; + } + else if (arg_coli > coli) + { + coli = arg_coli; + my_idx = Base::mat.cidx (arg_coli); + } + jump_to_row (rowj); + return coli == arg_coli && this->row () == rowj; + } + + double + interp_idx (void) const + { + return default_interp_idx(*this); + } + octave_idx_type + flat_idx (void) const + { + return default_flat_idx(*this); + } + bool + skip_ahead (octave_idx_type i) + { + return default_skip_ahead (*this, i); } }; template -class diag_iterator +class diag_iterator : public nz_iterator > { private: - const DiagArray2& mat; + typedef nz_iterator > Base; + octave_idx_type my_idx; template @@ -296,25 +430,19 @@ public: diag_iterator (const DiagArray2& arg_mat) : - mat (arg_mat) + nz_iterator > (arg_mat) { - dir_handler dirc; - begin (dirc); + begin (fdirc); } template void begin (dir_handler dirc) { - my_idx = dirc.begin (mat.diag_length ()); + my_idx = dirc.begin (Base::mat.diag_length ()); move_to_nz (dirc); } - double - interp_idx (void) const - { - return to_interp_idx (row (), col (), mat.dims ()); - } octave_idx_type col (void) const { @@ -325,10 +453,11 @@ { return my_idx; } + T data (void) const { - return mat.dgelem (my_idx); + return Base::mat.dgelem (my_idx); } template void @@ -341,37 +470,58 @@ bool finished (dir_handler dirc) const { - return dirc.is_ended (my_idx, mat.diag_length ()); + return dirc.is_ended (my_idx, Base::mat.diag_length ()); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type coli) + { + if (coli < my_idx) + return false; + my_idx = coli + (rowj > coli); + move_to_nz (fdirc); + return rowj == coli && coli == my_idx; + } + + double + interp_idx (void) const + { + return default_interp_idx(*this); + } + octave_idx_type + flat_idx (void) const + { + return default_flat_idx(*this); + } + bool + skip_ahead (octave_idx_type i) + { + return default_skip_ahead (*this, i); } }; -class perm_iterator +class perm_iterator : public nz_iterator { private: - const PermMatrix& mat; + typedef nz_iterator Base; + octave_idx_type my_idx; public: perm_iterator (const PermMatrix& arg_mat) : - mat (arg_mat) + nz_iterator (arg_mat) { - dir_handler dirc; - begin (dirc); + begin (fdirc); } template void begin (dir_handler dirc) { - my_idx = dirc.begin (mat.cols ()); + my_idx = dirc.begin (Base::mat.cols ()); } octave_idx_type - interp_idx (void) const - { - return to_interp_idx (row (), col (), mat.dims ()); - } - octave_idx_type col (void) const { return my_idx; @@ -379,8 +529,9 @@ octave_idx_type row (void) const { - return mat.perm_elem (my_idx); + return Base::mat.perm_elem (my_idx); } + bool data (void) const { @@ -396,8 +547,68 @@ bool finished (dir_handler dirc) const { - return dirc.is_ended (my_idx, mat.rows ()); + return dirc.is_ended (my_idx, Base::mat.rows ()); + } + + bool + skip_ahead (octave_idx_type rowj, octave_idx_type coli) + { + //TODO Check bounds + if (coli < my_idx || rowj < this->row ()) + return false; + my_idx = coli + (rowj > Base::mat.perm_elem (coli)); + return my_idx == coli && rowj == this->row (); + } + + double + interp_idx (void) const + { + return default_interp_idx(*this); + } + octave_idx_type + flat_idx (void) const + { + return default_flat_idx(*this); + } + bool + skip_ahead (octave_idx_type i) + { + return default_skip_ahead (*this, i); } }; +// Uses a one-sided binary search to move to the next element in this column +// whose row is at least rowj. The one-sided binary search guarantees +// O(log(rowj - currentRow)) time to find it. +template +void +sparse_iterator::jump_to_row (octave_idx_type rowj) +{ + octave_idx_type ub = Base::mat.cidx (coli + 1); + octave_idx_type lo = my_idx - 1; + octave_idx_type hi = my_idx; + octave_idx_type hidiff = 1; + while (Base::mat.ridx (hi) < rowj) + { + lo = hi; + hidiff *= 2; + hi += hidiff; + if (hi >= ub) + { + hi = ub; + break; + } + } + while (hi - lo > 1) + { + octave_idx_type mid = (lo + hi) / 2; + if (Base::mat.ridx (mid) < rowj) + lo = mid; + else + hi = mid; + } + my_idx = hi; + adjust_col (fdirc); +} + #endif diff -r 8d47ce2053f2 -r 3fb030666878 test/logical-index.tst --- a/test/logical-index.tst Mon Jul 14 13:07:59 2014 -0600 +++ b/test/logical-index.tst Fri Jul 25 13:39:31 2014 -0600 @@ -30,13 +30,18 @@ %!assert (isempty (a(logical ([0,0,0,0])))) %!assert (a(logical ([1,1,1,1])), [9,8,7,6]) %!assert (a(logical ([0,1,1,0])), [8,7]) +%!assert (a(logical ([0;1;1;0])), [8,7]) %!assert (a(logical ([1,1])), [9,8]) +%! a = permute(a,[1,3,2]); +%!assert (a(logical ([0,1,1,0])), permute([8,7],[1,3,2])) + %!shared a %! a = [9,8;7,6]; -%!assert (isempty (a(logical ([0,0,0,0])))) +%!assert (a(logical ([0,0,0,0])), zeros(1,0)) %!assert (a(logical ([1,1,1,1])), [9,7,8,6]) %!assert (a(logical ([0,1,1,0])), [7,8]) +%!assert (a(logical ([0;1;1;0])), [7;8]) %!assert (a(logical (0:1),logical (0:1)), 6) %!assert (a(logical (0:1),2:-1:1), [6,7]) %!assert (a(logical (0:1),logical ([0,1])), 6) @@ -71,3 +76,21 @@ %!assert (a(logical ([1,1]),1), [9;7]) %!assert (a(logical ([1,1]),logical ([1,1])), [9,8;7,6]) +%!assert (a(logical (permute([1,1,1,1],[1,3,2]))), [9;7;8;6]) +%!assert (a(logical (permute([0,1,1,0],[1,3,2]))), [7;8]) +%!assert (a(logical([])), []) +%!assert (a(false), []) +%!assert (a(false(0,0,1)), []) +%!assert (a(false(2,2)),zeros(0,1)) +%!assert (a(false(1,4)),zeros(1,0)) +%!assert (a(false(4,1)),zeros(0,1)) +%!assert (a(false(1,1,4)),zeros(0,1)) + +%!shared v,a,b,c +%! v = sparse((1:100000).'); +%! v = v .* mod(v,2); +%! a = sparse(diag(v)); +%! b = logical(speye(100000)); +%! c = [b(:,1:2:end); b(:,2:2:end)]; +%!assert (a(b), v) +%!assert (a(c), v)