comparison liboctave/util/logical-index.h @ 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function. * logical-index.h (New file): Logical-indexing function. May be called on octave_value types via call_bool_index. * nz-iterators.h: Add base class nz_iterator for iterator types. array_iterator has a template bool parameter for whether to internally store the row-col position or compute it on the fly. Add skip_ahead method, which skips forward to the next nonzero after its argument. Add flat_idx for computing the octave_idx_type index of the current position (with an assertion failure in the case of overflow). Move is_zero to a separate file. * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h. * Array.h: Move forward declaration for array_iterator to a separate header file. * dim-vector.cc (dim_max): Refer to idx-bounds.h (max_idx). * array-iter-decl.h (New file): Header file for forward declaration of array_iterator. * direction.h: Add constants fdirc and bdirc to avoid having to reconstruct them. * dv-utils.h, dv-utils.cc (New files): Utility functions for querying and constructing dim-vectors. * idx-bounds.h (New file): Utility constants and functions for determining whether things will overflow the maximum allowed bounds. * interp-idx.h (New function: to_flat_idx): Converts a row-col pair to a linear index of type octave_idx_type. * is-zero.h (New file): Function for determining whether an element is zero. * logical-index.tst: Add tests for correct return-value dimensions and large sparse matrix behavior.
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents
children
comparison
equal deleted inserted replaced
19009:8d47ce2053f2 19010:3fb030666878
1 /*
2
3 Copyright (C) 2014 David Spies
4
5 This file is part of Octave.
6
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20
21 */
22 #if !defined (octave_logical_index_h)
23 #define octave_logical_index_h 1
24
25 #include "Array.h"
26 #include "dim-vector.h"
27 #include "dv-utils.h"
28 #include "direction.h"
29 #include "dispatch.h"
30 #include "nz-iterators.h"
31 #include "ov.h"
32
33 //Reshapes idx to have dims_rows rows (adding extra zeros if necessary)
34 template<typename IM>
35 Sparse<bool>
36 partial_sparse_reshape (const IM& idx, octave_idx_type dims_rows)
37 {
38 typename IM::iter_type idx_iter (idx);
39
40 dim_vector res_dims = dim_vector (idx.nnz (), 1);
41
42 Array<octave_idx_type> rows (res_dims);
43 Array<octave_idx_type> cols (res_dims);
44
45 const octave_idx_type idx_rows = idx.rows ();
46
47 octave_idx_type i;
48 octave_idx_type col_start_col = 0;
49 octave_idx_type col_start_row = 0;
50 octave_idx_type idx_col = 0;
51 for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc);
52 ++i, idx_iter.step (fdirc))
53 {
54 for (; idx_col < idx_iter.col (); ++idx_col)
55 {
56 octave_idx_type next_row = col_start_row + idx_rows;
57 col_start_col += next_row / dims_rows;
58 col_start_row = next_row % dims_rows;
59 }
60 octave_idx_type next_row = col_start_row + idx_iter.row ();
61 cols.xelem (i) = col_start_col + next_row / dims_rows;
62 rows.xelem (i) = next_row % dims_rows;
63 }
64 Array<bool> trueScalar (dim_vector (1, 1), true);
65 return Sparse<bool> (trueScalar, idx_vector (rows), idx_vector (cols));
66 }
67
68 // Given matrix mat and logical matrix idx, takes mat(idx) and returns the
69 // result as an Array. "full" indicates if mat is full (ie can use linear index)
70 // or not. IterM is the nonzero-iterator type for M (the type of mat)
71 // if !full, mat and idx must have the same height
72 template<bool full, typename IterM, typename M, typename IM>
73 Array<typename M::element_type>
74 take_bool_with_index (const M& mat, const IM& idx)
75 {
76 typedef typename M::element_type ELT_T;
77
78 // After testing over 150 edge-cases in Matlab, this seems to be the rule for
79 // logical indexing return-value dimensions
80 const dim_vector idx_dims = idx.dims ();
81 const dim_vector mat_dims = mat.dims ();
82 octave_idx_type idx_nnz = idx.nnz ();
83 dim_vector res_dims;
84 if (idx_dims(0) == 0 && idx_dims(1) == 0)
85 res_dims = dim_vector (0, 0);
86 else if (dv_is_scalar (idx_dims) && idx_nnz == 0)
87 res_dims = dim_vector (0, 0);
88 else if (dv_is_extended_vector (mat_dims))
89 res_dims = dv_match_vector (mat_dims, idx_nnz);
90 else
91 {
92 if (dv_is_row (idx_dims))
93 res_dims = dim_vector (1, idx_nnz);
94 else
95 res_dims = dim_vector (idx_nnz, 1);
96 }
97
98 IterM mat_iter (mat);
99 typename IM::iter_type idx_iter (idx);
100
101 Array<ELT_T> res (res_dims);
102
103 octave_idx_type i;
104 for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc);
105 ++i, idx_iter.step (fdirc))
106 {
107 bool nz;
108 if (full)
109 nz = mat_iter.skip_ahead (idx_iter.flat_idx ());
110 else
111 nz = mat_iter.skip_ahead (idx_iter.row (), idx_iter.col ());
112 if (nz)
113 res.xelem (i) = mat_iter.data ();
114 else
115 res.xelem (i) = static_cast<ELT_T> (0);
116 }
117 return res;
118 }
119
120 // Determine whether we need to reshape idx before calling take_bool_with_index
121 // to ensure mat and idx have the same height
122 template<bool full, typename IterM, typename M, typename IM>
123 Array<typename M::element_type>
124 take_bool_index (const M& mat, const IM& idx)
125 {
126 if (full || idx.rows () == mat.rows ())
127 return take_bool_with_index<full, IterM> (mat, idx);
128 else
129 return take_bool_with_index<full, IterM> (
130 mat, partial_sparse_reshape (idx, mat.rows ()));
131 }
132
133 template<typename IM>
134 Array<bool>
135 bool_index (const PermMatrix& mat, const IM& idx)
136 {
137 return take_bool_index<false, perm_iterator> (mat, idx);
138 }
139
140 template<typename ELT_T, typename IM>
141 Array<ELT_T>
142 bool_index (const DiagArray2<ELT_T>& mat, const IM& idx)
143 {
144 return take_bool_index<false, diag_iterator<ELT_T> > (mat, idx);
145 }
146
147 template<typename ELT_T, typename IM>
148 Array<ELT_T>
149 bool_index (const Array<ELT_T>& mat, const IM& idx)
150 {
151 return take_bool_index<true, array_iterator<ELT_T, false> > (mat, idx);
152 }
153
154 template<typename ELT_T, typename IM>
155 Sparse<ELT_T>
156 bool_index (const Sparse<ELT_T>& mat, const IM& idx)
157 {
158 const Array<ELT_T> res = take_bool_index<false, sparse_iterator<ELT_T> > (
159 mat, idx);
160 return Sparse<ELT_T> (res);
161 }
162
163
164 template<typename M>
165 struct mwrapper
166 {
167 template<typename IM>
168 struct idx_caller
169 {
170 octave_value_list
171 operator() (const IM& arg, const M& into)
172 {
173 octave_value_list res (1);
174 res(0) = bool_index (into, arg);
175 return res;
176 }
177 };
178
179 // Dispatch call to bool_index for type of idx. The nested functor allows
180 // for multiple template parameters (since dispatch assumes its template
181 // argument has exactly one template parameter)
182 static octave_value
183 do_call (const M& mat, const octave_value& idx)
184 {
185 octave_value_list res = dispatch<idx_caller> (idx, mat, "bool_index");
186 return res(0);
187 }
188 };
189
190 template<typename M>
191 octave_value
192 call_bool_index (const M& mat, const octave_value& idx)
193 {
194 return mwrapper<M>::do_call (mat, idx);
195 }
196
197 #endif