Mercurial > octave-dspies
comparison liboctave/util/logical-index.h @ 19010:3fb030666878 draft default tip dspies
Added special-case logical-indexing function
* logical-index.h (New file) : Logical-indexing function. May be called on
octave_value types via call_bool_index
* nz-iterators.h : Add base-class nz_iterator for iterator types. The Array
iterator takes a bool template parameter selecting whether it stores row-col
internally or computes them on the fly
Add skip_ahead method which skips forward to the next nonzero after its
argument
Add flat_index for computing octave_idx_type index of current position (with
assertion failure in the case of overflow)
Move is_zero to separate file
* ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc
(do_index_op): Add call to call_bool_index in logical-index.h
* Array.h : Move forward-declaration for array_iterator to separate header file
* dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx)
* array-iter-decl.h (New file): Header file for forward declaration of
array-iterator
* direction.h : Add constants fdirc and bdirc to avoid having to reconstruct
them
* dv-utils.h, dv-utils.cc (New files) :
Utility functions for querying and constructing dim-vectors
* idx-bounds.h (New file) :
Utility constants and functions for determining whether things will overflow
the maximum allowed bounds
* interp-idx.h (New function : to_flat_idx) : Converts a row-col pair to a
linear index of type octave_idx_type
* is-zero.h (New file) : Function for determining whether an element is zero
* logical-index.tst : Add tests for correct return-value dimensions and large
sparse matrix behavior
author | David Spies <dnspies@gmail.com> |
---|---|
date | Fri, 25 Jul 2014 13:39:31 -0600 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
19009:8d47ce2053f2 | 19010:3fb030666878 |
---|---|
1 /* | |
2 | |
3 Copyright (C) 2014 David Spies | |
4 | |
5 This file is part of Octave. | |
6 | |
7 Octave is free software; you can redistribute it and/or modify it | |
8 under the terms of the GNU General Public License as published by the | |
9 Free Software Foundation; either version 3 of the License, or (at your | |
10 option) any later version. | |
11 | |
12 Octave is distributed in the hope that it will be useful, but WITHOUT | |
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or | |
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License | |
15 for more details. | |
16 | |
17 You should have received a copy of the GNU General Public License | |
18 along with Octave; see the file COPYING. If not, see | |
19 <http://www.gnu.org/licenses/>. | |
20 | |
21 */ | |
22 #if !defined (octave_logical_index_h) | |
23 #define octave_logical_index_h 1 | |
24 | |
25 #include "Array.h" | |
26 #include "dim-vector.h" | |
27 #include "dv-utils.h" | |
28 #include "direction.h" | |
29 #include "dispatch.h" | |
30 #include "nz-iterators.h" | |
31 #include "ov.h" | |
32 | |
33 //Reshapes idx to have dims_rows rows (adding extra zeros if necessary) | |
34 template<typename IM> | |
35 Sparse<bool> | |
36 partial_sparse_reshape (const IM& idx, octave_idx_type dims_rows) | |
37 { | |
38 typename IM::iter_type idx_iter (idx); | |
39 | |
40 dim_vector res_dims = dim_vector (idx.nnz (), 1); | |
41 | |
42 Array<octave_idx_type> rows (res_dims); | |
43 Array<octave_idx_type> cols (res_dims); | |
44 | |
45 const octave_idx_type idx_rows = idx.rows (); | |
46 | |
47 octave_idx_type i; | |
48 octave_idx_type col_start_col = 0; | |
49 octave_idx_type col_start_row = 0; | |
50 octave_idx_type idx_col = 0; | |
51 for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc); | |
52 ++i, idx_iter.step (fdirc)) | |
53 { | |
54 for (; idx_col < idx_iter.col (); ++idx_col) | |
55 { | |
56 octave_idx_type next_row = col_start_row + idx_rows; | |
57 col_start_col += next_row / dims_rows; | |
58 col_start_row = next_row % dims_rows; | |
59 } | |
60 octave_idx_type next_row = col_start_row + idx_iter.row (); | |
61 cols.xelem (i) = col_start_col + next_row / dims_rows; | |
62 rows.xelem (i) = next_row % dims_rows; | |
63 } | |
64 Array<bool> trueScalar (dim_vector (1, 1), true); | |
65 return Sparse<bool> (trueScalar, idx_vector (rows), idx_vector (cols)); | |
66 } | |
67 | |
68 // Given matrix mat and logical matrix idx, takes mat(idx) and returns the | |
69 // result as an Array. "full" indicates if mat is full (ie can use linear index) | |
70 // or not. IterM is the nonzero-iterator type for M (the type of mat) | |
71 // if !full, mat and idx must have the same height | |
72 template<bool full, typename IterM, typename M, typename IM> | |
73 Array<typename M::element_type> | |
74 take_bool_with_index (const M& mat, const IM& idx) | |
75 { | |
76 typedef typename M::element_type ELT_T; | |
77 | |
78 // After testing over 150 edge-cases in Matlab, this seems to be the rule for | |
79 // logical indexing return-value dimensions | |
80 const dim_vector idx_dims = idx.dims (); | |
81 const dim_vector mat_dims = mat.dims (); | |
82 octave_idx_type idx_nnz = idx.nnz (); | |
83 dim_vector res_dims; | |
84 if (idx_dims(0) == 0 && idx_dims(1) == 0) | |
85 res_dims = dim_vector (0, 0); | |
86 else if (dv_is_scalar (idx_dims) && idx_nnz == 0) | |
87 res_dims = dim_vector (0, 0); | |
88 else if (dv_is_extended_vector (mat_dims)) | |
89 res_dims = dv_match_vector (mat_dims, idx_nnz); | |
90 else | |
91 { | |
92 if (dv_is_row (idx_dims)) | |
93 res_dims = dim_vector (1, idx_nnz); | |
94 else | |
95 res_dims = dim_vector (idx_nnz, 1); | |
96 } | |
97 | |
98 IterM mat_iter (mat); | |
99 typename IM::iter_type idx_iter (idx); | |
100 | |
101 Array<ELT_T> res (res_dims); | |
102 | |
103 octave_idx_type i; | |
104 for (i = 0, idx_iter.begin (fdirc); !idx_iter.finished (fdirc); | |
105 ++i, idx_iter.step (fdirc)) | |
106 { | |
107 bool nz; | |
108 if (full) | |
109 nz = mat_iter.skip_ahead (idx_iter.flat_idx ()); | |
110 else | |
111 nz = mat_iter.skip_ahead (idx_iter.row (), idx_iter.col ()); | |
112 if (nz) | |
113 res.xelem (i) = mat_iter.data (); | |
114 else | |
115 res.xelem (i) = static_cast<ELT_T> (0); | |
116 } | |
117 return res; | |
118 } | |
119 | |
120 // Determine whether we need to reshape idx before calling take_bool_with_index | |
121 // to ensure mat and idx have the same height | |
122 template<bool full, typename IterM, typename M, typename IM> | |
123 Array<typename M::element_type> | |
124 take_bool_index (const M& mat, const IM& idx) | |
125 { | |
126 if (full || idx.rows () == mat.rows ()) | |
127 return take_bool_with_index<full, IterM> (mat, idx); | |
128 else | |
129 return take_bool_with_index<full, IterM> ( | |
130 mat, partial_sparse_reshape (idx, mat.rows ())); | |
131 } | |
132 | |
133 template<typename IM> | |
134 Array<bool> | |
135 bool_index (const PermMatrix& mat, const IM& idx) | |
136 { | |
137 return take_bool_index<false, perm_iterator> (mat, idx); | |
138 } | |
139 | |
140 template<typename ELT_T, typename IM> | |
141 Array<ELT_T> | |
142 bool_index (const DiagArray2<ELT_T>& mat, const IM& idx) | |
143 { | |
144 return take_bool_index<false, diag_iterator<ELT_T> > (mat, idx); | |
145 } | |
146 | |
147 template<typename ELT_T, typename IM> | |
148 Array<ELT_T> | |
149 bool_index (const Array<ELT_T>& mat, const IM& idx) | |
150 { | |
151 return take_bool_index<true, array_iterator<ELT_T, false> > (mat, idx); | |
152 } | |
153 | |
154 template<typename ELT_T, typename IM> | |
155 Sparse<ELT_T> | |
156 bool_index (const Sparse<ELT_T>& mat, const IM& idx) | |
157 { | |
158 const Array<ELT_T> res = take_bool_index<false, sparse_iterator<ELT_T> > ( | |
159 mat, idx); | |
160 return Sparse<ELT_T> (res); | |
161 } | |
162 | |
163 | |
164 template<typename M> | |
165 struct mwrapper | |
166 { | |
167 template<typename IM> | |
168 struct idx_caller | |
169 { | |
170 octave_value_list | |
171 operator() (const IM& arg, const M& into) | |
172 { | |
173 octave_value_list res (1); | |
174 res(0) = bool_index (into, arg); | |
175 return res; | |
176 } | |
177 }; | |
178 | |
179 // Dispatch call to bool_index for type of idx. The nested functor allows | |
180 // for multiple template parameters (since dispatch assumes its template | |
181 // argument has exactly one template parameter) | |
182 static octave_value | |
183 do_call (const M& mat, const octave_value& idx) | |
184 { | |
185 octave_value_list res = dispatch<idx_caller> (idx, mat, "bool_index"); | |
186 return res(0); | |
187 } | |
188 }; | |
189 | |
190 template<typename M> | |
191 octave_value | |
192 call_bool_index (const M& mat, const octave_value& idx) | |
193 { | |
194 return mwrapper<M>::do_call (mat, idx); | |
195 } | |
196 | |
197 #endif |