comparison liboctave/util/nz-iterators.h @ 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function * logical-index.h (New file) : Logical-indexing function. May be called on octave_value types via call_bool_index * nz-iterators.h : Add base-class nz_iterator for iterator types. Array has template bool for whether to internally store row-col or compute on the fly Add skip_ahead method which skips forward to the next nonzero after its argument Add flat_index for computing octave_idx_type index of current position (with assertion failure in the case of overflow) Move is_zero to separate file * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h * Array.h : Move forward-declaration for array_iterator to separate header file * dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx) * array-iter-decl.h (New file): Header file for forward declaration of array-iterator * direction.h : Add constants fdirc and bdirc to avoid having to reconstruct them * dv-utils.h, dv-utils.cc (New files) : Utility functions for querying and constructing dim-vectors * idx-bounds.h (New file) : Utility constants and functions for determining whether things will overflow the maximum allowed bounds * interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear index of octave_idx_type * is-zero.h (New file) : Function for determining whether an element is zero * logical-index.tst : Add tests for correct return-value dimensions and large sparse matrix behavior
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents 8d47ce2053f2
children
comparison
equal deleted inserted replaced
19009:8d47ce2053f2 19010:3fb030666878
20 20
21 */ 21 */
22 #if !defined (octave_nz_iterators_h) 22 #if !defined (octave_nz_iterators_h)
23 #define octave_nz_iterators_h 1 23 #define octave_nz_iterators_h 1
24 24
25 #include <cassert>
26
27 #include "array-iter-decl.h"
25 #include "interp-idx.h" 28 #include "interp-idx.h"
26 #include "oct-inttypes.h" 29 #include "oct-inttypes.h"
27 #include "Array.h" 30 #include "Array.h"
28 #include "DiagArray2.h" 31 #include "DiagArray2.h"
29 #include "PermMatrix.h" 32 #include "PermMatrix.h"
30 #include "Sparse.h" 33 #include "Sparse.h"
31 #include "direction.h" 34 #include "direction.h"
35 #include "is-zero.h"
32 36
33 // This file contains generic column-major iterators over 37 // This file contains generic column-major iterators over
34 // the nonzero elements of any array or matrix. If you have a matrix mat 38 // the nonzero elements of any array or matrix. If you have a matrix mat
35 // of type M, you can construct the proper iterator type using 39 // of type M, you can construct the proper iterator type using
36 // M::iter_type iter(mat) and iter will iterate efficiently (forwards 40 // M::iter_type iter(mat) and iter will iterate efficiently (forwards
59 // octave_idx_type col = iter.col(); 63 // octave_idx_type col = iter.col();
60 // double doub_index = iter.interp_index (); 64 // double doub_index = iter.interp_index ();
61 // T elem = iter.data(); 65 // T elem = iter.data();
62 // // ... Do something with these 66 // // ... Do something with these
63 // } 67 // }
64 // 68
65 // Note that array_iter for indexing over full matrices also includes 69 struct rowcol
66 // a iter.flat_index () method which returns an octave_idx_type. 70 {
67 // 71 rowcol (void) { }
68 // The other iterators to not have a flat_index() method because they 72
69 // risk overflowing octave_idx_type. It is recommended you take care 73 rowcol (octave_idx_type i, const dim_vector& dims)
70 // to implement your function in a way that accounts for this problem. 74 {
71 // 75 #if defined(BOUNDS_CHECKING)
72 // FIXME: I'd like to add in these 76 check_index(i, dims);
73 // default no-parameter versions of 77 #endif
74 // begin() and step() to each of the 78 row = i % dims(0);
75 // classes. But the C++ compiler complains 79 col = i / dims(0);
76 // because apparently I'm not allowed to overload 80 }
77 // templated methods with non-templated ones. Any 81
78 // ideas for work-arounds? 82 rowcol (octave_idx_type rowj, octave_idx_type coli, const dim_vector& dims)
79 // 83 {
80 //#define INCLUDE_DEFAULT_STEPS \ 84 #if defined(BOUNDS_CHECKING)
81 // void begin (void) \ 85 check_index(rowj, coli, dims);
82 // { \ 86 #endif
83 // dir_handler<FORWARD> dirc; \ 87 row = rowj;
84 // begin (dirc); \ 88 col = coli;
85 // } \ 89 }
86 // void step (void) \ 90
87 // { \ 91 octave_idx_type row;
88 // dir_handler<FORWARD> dirc; \ 92 octave_idx_type col;
89 // step (dirc); \ 93 };
90 // } \ 94
91 // bool finished (void) const \ 95 // Default (non-linear) implementation of nz_iterator.interp_idx
92 // { \ 96 template <typename Iter>
93 // dir_handler<FORWARD> dirc; \ 97 inline double
94 // return finished (dirc); \ 98 default_interp_idx(const Iter& it)
95 // } 99 {
96 100 return to_interp_idx (it.row (), it.col (), it.dims);
97 // A generic method for checking if some element of a matrix with
98 // element type T is zero.
99 template<typename T>
100 bool
101 is_zero (T t)
102 {
103 return t == static_cast<T> (0);
104 } 101 }
102
103 // Default (non-linear) implementation of nz_iterator.flat_idx
104 template <typename Iter>
105 inline octave_idx_type
106 default_flat_idx (const Iter& it)
107 {
108 return to_flat_idx (it.row (), it.col (), it.dims);
109 }
110
111 // Default (non-linear) implementation of
112 // nz_iterator.skip_ahead(octave_idx_type)
113 template <typename Iter>
114 inline bool
115 default_skip_ahead (Iter& it, octave_idx_type i)
116 {
117 return it.skip_ahead (i % it.dims(0), i / it.dims(0));
118 }
119
120
121 template <typename M>
122 class nz_iterator
123 {
124 protected:
125 const M mat;
126
127 public:
128 const dim_vector dims;
129 nz_iterator (const M& arg_mat) : mat (arg_mat), dims (arg_mat.dims ()) { }
130
131 // Because of the abundance of templates, this interface cannot be specified
132 // using pure virtual methods. But here's what each method does:
133 //
134 // begin(dirc): moves the iterator to the beginning or the end
135 // finished(dirc): checks whether the iterator has advanced to the end or the
136 // beginning
137 // row, col, and data: return the current row, column, and element
138 // interp_idx and flat_idx: return the current linear index (interp_idx as
139 // a double the way the interpreter sees it and flat_idx as an
140 // octave_idx_type)
141 // step(dirc): steps forward to the next nonzero element
142 // skip_ahead(args): skips forward to the nearest nonzero element after
143 // (including) args. Returns true iff the position at args is nonzero
144 };
105 145
106 // An iterator over full arrays. When the number of dimensions exceeds 146 // An iterator over full arrays. When the number of dimensions exceeds
107 // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1) 147 // 2, calls to iter.col() may exceed mat.cols() up to mat.dims().numel(1)
108 // 148 //
109 // This mimics the behavior of the "find" method (both in Octave and Matlab) 149 // This mimics the behavior of the "find" method (both in Octave and Matlab)
110 // on many-dimensional matrices. 150 // on many-dimensional matrices.
111 151 //
112 template<typename T> 152 // userc indicates whether to (true) track row and column values or (false)
113 class array_iterator 153 // compute them on the fly as needed
154
155 template<typename T, bool userc>
156 class array_iterator : public nz_iterator<Array<T> >
114 { 157 {
115 private: 158 private:
116 const Array<T>& mat; 159 typedef nz_iterator<Array<T> > Base;
117 160
118 //Actual total number of columns = mat.dims().numel(1)
119 //can be different from length of row dimension
120 const octave_idx_type totcols; 161 const octave_idx_type totcols;
121 const octave_idx_type numels; 162 const octave_idx_type numels;
122 163
123 octave_idx_type coli; 164 rowcol myrc;
124 octave_idx_type rowj; 165
125 octave_idx_type my_idx; 166 octave_idx_type my_idx;
126 167
127 template<direction dir> 168 template<direction dir>
128 void 169 void
129 step_once (dir_handler<dir> dirc) 170 step_once (dir_handler<dir> dirc)
130 { 171 {
131 my_idx += dir; 172 my_idx += dir;
132 rowj += dir; 173 if (userc)
133 if (dirc.is_ended (rowj, mat.rows ()))
134 { 174 {
135 rowj = dirc.begin (mat.rows ()); 175 myrc.row += dir;
136 coli += dir; 176 if (dirc.is_ended (myrc.row, Base::mat.rows ()))
177 {
178 myrc.row = dirc.begin (Base::mat.rows ());
179 myrc.col += dir;
180 }
137 } 181 }
138 } 182 }
139 183
140 template<direction dir> 184 template<direction dir>
141 void 185 void
145 { 189 {
146 step_once (dirc); 190 step_once (dirc);
147 } 191 }
148 } 192 }
149 193
194 template <direction dir>
195 bool
196 next_nz (dir_handler<dir> dirc)
197 {
198 if (is_zero (data ()))
199 {
200 step (dirc);
201 return false;
202 }
203 else
204 return true;
205 }
206
207 bool
208 skip_ahead (octave_idx_type i, const rowcol& rc)
209 {
210 if (i < my_idx)
211 return false;
212 my_idx = i;
213 if (userc)
214 myrc = rc;
215 return next_nz (fdirc);
216 }
150 217
151 public: 218 public:
152 array_iterator (const Array<T>& arg_mat) 219 array_iterator (const Array<T>& arg_mat)
153 : mat (arg_mat), totcols (arg_mat.dims ().numel (1)), numels ( 220 : nz_iterator<Array<T> > (arg_mat), totcols (Base::dims.numel (1)), numels (
154 totcols * arg_mat.rows ()) 221 totcols * arg_mat.rows ())
155 { 222 {
156 dir_handler<FORWARD> dirc; 223 begin (fdirc);
157 begin (dirc);
158 } 224 }
159 225
160 template<direction dir> 226 template<direction dir>
161 void 227 void
162 begin (dir_handler<dir> dirc) 228 begin (dir_handler<dir> dirc)
163 { 229 {
164 coli = dirc.begin (totcols); 230 if(userc)
165 rowj = dirc.begin (mat.rows ()); 231 {
166 my_idx = dirc.begin (mat.numel ()); 232 myrc.col = dirc.begin (totcols);
233 myrc.row = dirc.begin (Base::mat.rows ());
234 }
235 my_idx = dirc.begin (Base::mat.numel ());
167 move_to_nz (dirc); 236 move_to_nz (dirc);
168 } 237 }
169 238
170 octave_idx_type 239 octave_idx_type
171 col (void) const 240 col (void) const
172 { 241 {
173 return coli; 242 if (userc)
243 return myrc.col;
244 else
245 return my_idx / Base::mat.rows ();
174 } 246 }
175 octave_idx_type 247 octave_idx_type
176 row (void) const 248 row (void) const
177 { 249 {
178 return rowj; 250 if (userc)
251 return myrc.row;
252 else
253 return my_idx % Base::mat.rows ();
179 } 254 }
180 double 255 double
181 interp_idx (void) const 256 interp_idx (void) const
182 { 257 {
183 return to_interp_idx (my_idx); 258 return to_interp_idx (my_idx);
188 return my_idx; 263 return my_idx;
189 } 264 }
190 T 265 T
191 data (void) const 266 data (void) const
192 { 267 {
193 return mat.elem (my_idx); 268 return Base::mat.elem (my_idx);
194 } 269 }
195 270
196 template<direction dir> 271 template<direction dir>
197 void 272 void
198 step (dir_handler<dir> dirc) 273 step (dir_handler<dir> dirc)
204 bool 279 bool
205 finished (dir_handler<dir> dirc) const 280 finished (dir_handler<dir> dirc) const
206 { 281 {
207 return dirc.is_ended (my_idx, numels); 282 return dirc.is_ended (my_idx, numels);
208 } 283 }
284
285 bool
286 skip_ahead (octave_idx_type i)
287 {
288 return skip_ahead (i, rowcol (i, Base::dims));
289 }
290
291 bool
292 skip_ahead (octave_idx_type rowj, octave_idx_type coli)
293 {
294 return skip_ahead (coli * Base::mat.rows () + rowj,
295 rowcol (rowj, coli, Base::dims));
296 }
297
298 bool
299 skip_ahead (const Array<octave_idx_type>& idxs)
300 {
301 //TODO Check bounds
302 octave_idx_type rowj = idxs(0);
303 octave_idx_type coli = 0;
304 for(int i = idxs.numel () - 1; i > 1; --i) {
305 coli += idxs(i);
306 coli *= Base::dims(i);
307 }
308 coli += idxs(1);
309 return skip_ahead (rowj, coli);
310 }
209 }; 311 };
210 312
211 template<typename T> 313 template<typename T>
212 class sparse_iterator 314 class sparse_iterator : public nz_iterator<Sparse<T> >
213 { 315 {
214 private: 316 private:
215 const Sparse<T>& mat; 317 typedef nz_iterator<Sparse<T> > Base;
318
216 octave_idx_type coli; 319 octave_idx_type coli;
217 octave_idx_type my_idx; 320 octave_idx_type my_idx;
218 321
219 template<direction dir> 322 template<direction dir>
220 void 323 void
221 adjust_col (dir_handler<dir> dirc) 324 adjust_col (dir_handler<dir> dirc)
222 { 325 {
223 while (!finished (dirc) 326 while (!finished (dirc)
224 && dirc.is_ended (my_idx, mat.cidx (coli), mat.cidx (coli + 1))) 327 && dirc.is_ended (my_idx, Base::mat.cidx (coli), Base::mat.cidx (coli + 1)))
225 coli += dir; 328 coli += dir;
226 } 329 }
330
331 void jump_to_row (octave_idx_type rowj);
227 332
228 public: 333 public:
229 sparse_iterator (const Sparse<T>& arg_mat) : 334 sparse_iterator (const Sparse<T>& arg_mat) :
230 mat (arg_mat) 335 nz_iterator<Sparse<T> > (arg_mat)
231 { 336 {
232 dir_handler<FORWARD> dirc; 337 begin (fdirc);
233 begin (dirc);
234 } 338 }
235 339
236 template<direction dir> 340 template<direction dir>
237 void 341 void
238 begin (dir_handler<dir> dirc) 342 begin (dir_handler<dir> dirc)
239 { 343 {
240 coli = dirc.begin (mat.cols ()); 344 coli = dirc.begin (Base::mat.cols ());
241 my_idx = dirc.begin (mat.nnz ()); 345 my_idx = dirc.begin (Base::mat.nnz ());
242 adjust_col (dirc); 346 adjust_col (dirc);
347 }
348
349 octave_idx_type
350 col (void) const
351 {
352 return coli;
353 }
354 octave_idx_type
355 row (void) const
356 {
357 return Base::mat.ridx (my_idx);
358 }
359
360 T
361 data (void) const
362 {
363 return Base::mat.data (my_idx);
364 }
365 template<direction dir>
366 void
367 step (dir_handler<dir> dirc)
368 {
369 my_idx += dir;
370 adjust_col (dirc);
371 }
372 template<direction dir>
373 bool
374 finished (dir_handler<dir> dirc) const
375 {
376 return dirc.is_ended (coli, Base::mat.cols ());
377 }
378
379 bool
380 skip_ahead (octave_idx_type rowj, octave_idx_type arg_coli)
381 {
382 //TODO Check bounds
383 if (arg_coli < coli || (arg_coli == coli && rowj < this->row ()))
384 {
385 return false;
386 }
387 else if (arg_coli > coli)
388 {
389 coli = arg_coli;
390 my_idx = Base::mat.cidx (arg_coli);
391 }
392 jump_to_row (rowj);
393 return coli == arg_coli && this->row () == rowj;
243 } 394 }
244 395
245 double 396 double
246 interp_idx (void) const 397 interp_idx (void) const
247 { 398 {
248 return to_interp_idx (row (), col (), mat.dims ()); 399 return default_interp_idx(*this);
249 } 400 }
250 octave_idx_type 401 octave_idx_type
251 col (void) const 402 flat_idx (void) const
252 { 403 {
253 return coli; 404 return default_flat_idx(*this);
254 } 405 }
255 octave_idx_type 406 bool
256 row (void) const 407 skip_ahead (octave_idx_type i)
257 { 408 {
258 return mat.ridx (my_idx); 409 return default_skip_ahead (*this, i);
259 }
260 T
261 data (void) const
262 {
263 return mat.data (my_idx);
264 }
265 template<direction dir>
266 void
267 step (dir_handler<dir> dirc)
268 {
269 my_idx += dir;
270 adjust_col (dirc);
271 }
272 template<direction dir>
273 bool
274 finished (dir_handler<dir> dirc) const
275 {
276 return dirc.is_ended (coli, mat.cols ());
277 } 410 }
278 }; 411 };
279 412
280 template<typename T> 413 template<typename T>
281 class diag_iterator 414 class diag_iterator : public nz_iterator <DiagArray2<T> >
282 { 415 {
283 private: 416 private:
284 const DiagArray2<T>& mat; 417 typedef nz_iterator <DiagArray2<T> > Base;
418
285 octave_idx_type my_idx; 419 octave_idx_type my_idx;
286 420
287 template <direction dir> 421 template <direction dir>
288 void 422 void
289 move_to_nz (dir_handler<dir> dirc) 423 move_to_nz (dir_handler<dir> dirc)
294 } 428 }
295 } 429 }
296 430
297 public: 431 public:
298 diag_iterator (const DiagArray2<T>& arg_mat) : 432 diag_iterator (const DiagArray2<T>& arg_mat) :
299 mat (arg_mat) 433 nz_iterator<DiagArray2<T> > (arg_mat)
300 { 434 {
301 dir_handler<FORWARD> dirc; 435 begin (fdirc);
302 begin (dirc);
303 } 436 }
304 437
305 template<direction dir> 438 template<direction dir>
306 void 439 void
307 begin (dir_handler<dir> dirc) 440 begin (dir_handler<dir> dirc)
308 { 441 {
309 my_idx = dirc.begin (mat.diag_length ()); 442 my_idx = dirc.begin (Base::mat.diag_length ());
310 move_to_nz (dirc); 443 move_to_nz (dirc);
444 }
445
446 octave_idx_type
447 col (void) const
448 {
449 return my_idx;
450 }
451 octave_idx_type
452 row (void) const
453 {
454 return my_idx;
455 }
456
457 T
458 data (void) const
459 {
460 return Base::mat.dgelem (my_idx);
461 }
462 template<direction dir>
463 void
464 step (dir_handler<dir> dirc)
465 {
466 my_idx += dir;
467 move_to_nz (dirc);
468 }
469 template<direction dir>
470 bool
471 finished (dir_handler<dir> dirc) const
472 {
473 return dirc.is_ended (my_idx, Base::mat.diag_length ());
474 }
475
476 bool
477 skip_ahead (octave_idx_type rowj, octave_idx_type coli)
478 {
479 if (coli < my_idx)
480 return false;
481 my_idx = coli + (rowj > coli);
482 move_to_nz (fdirc);
483 return rowj == coli && coli == my_idx;
311 } 484 }
312 485
313 double 486 double
314 interp_idx (void) const 487 interp_idx (void) const
315 { 488 {
316 return to_interp_idx (row (), col (), mat.dims ()); 489 return default_interp_idx(*this);
317 } 490 }
318 octave_idx_type 491 octave_idx_type
319 col (void) const 492 flat_idx (void) const
320 { 493 {
321 return my_idx; 494 return default_flat_idx(*this);
322 } 495 }
323 octave_idx_type 496 bool
324 row (void) const 497 skip_ahead (octave_idx_type i)
325 { 498 {
326 return my_idx; 499 return default_skip_ahead (*this, i);
327 }
328 T
329 data (void) const
330 {
331 return mat.dgelem (my_idx);
332 }
333 template<direction dir>
334 void
335 step (dir_handler<dir> dirc)
336 {
337 my_idx += dir;
338 move_to_nz (dirc);
339 }
340 template<direction dir>
341 bool
342 finished (dir_handler<dir> dirc) const
343 {
344 return dirc.is_ended (my_idx, mat.diag_length ());
345 } 500 }
346 }; 501 };
347 502
348 class perm_iterator 503 class perm_iterator : public nz_iterator<PermMatrix>
349 { 504 {
350 private: 505 private:
351 const PermMatrix& mat; 506 typedef nz_iterator<PermMatrix> Base;
507
352 octave_idx_type my_idx; 508 octave_idx_type my_idx;
353 509
354 public: 510 public:
355 perm_iterator (const PermMatrix& arg_mat) : 511 perm_iterator (const PermMatrix& arg_mat) :
356 mat (arg_mat) 512 nz_iterator<PermMatrix> (arg_mat)
357 { 513 {
358 dir_handler<FORWARD> dirc; 514 begin (fdirc);
359 begin (dirc);
360 } 515 }
361 516
362 template<direction dir> 517 template<direction dir>
363 void 518 void
364 begin (dir_handler<dir> dirc) 519 begin (dir_handler<dir> dirc)
365 { 520 {
366 my_idx = dirc.begin (mat.cols ()); 521 my_idx = dirc.begin (Base::mat.cols ());
367 } 522 }
368 523
369 octave_idx_type 524 octave_idx_type
525 col (void) const
526 {
527 return my_idx;
528 }
529 octave_idx_type
530 row (void) const
531 {
532 return Base::mat.perm_elem (my_idx);
533 }
534
535 bool
536 data (void) const
537 {
538 return true;
539 }
540 template<direction dir>
541 void
542 step (dir_handler<dir>)
543 {
544 my_idx += dir;
545 }
546 template<direction dir>
547 bool
548 finished (dir_handler<dir> dirc) const
549 {
550 return dirc.is_ended (my_idx, Base::mat.rows ());
551 }
552
553 bool
554 skip_ahead (octave_idx_type rowj, octave_idx_type coli)
555 {
556 //TODO Check bounds
557 if (coli < my_idx || rowj < this->row ())
558 return false;
559 my_idx = coli + (rowj > Base::mat.perm_elem (coli));
560 return my_idx == coli && rowj == this->row ();
561 }
562
563 double
370 interp_idx (void) const 564 interp_idx (void) const
371 { 565 {
372 return to_interp_idx (row (), col (), mat.dims ()); 566 return default_interp_idx(*this);
373 } 567 }
374 octave_idx_type 568 octave_idx_type
375 col (void) const 569 flat_idx (void) const
376 { 570 {
377 return my_idx; 571 return default_flat_idx(*this);
378 } 572 }
379 octave_idx_type 573 bool
380 row (void) const 574 skip_ahead (octave_idx_type i)
381 { 575 {
382 return mat.perm_elem (my_idx); 576 return default_skip_ahead (*this, i);
383 }
384 bool
385 data (void) const
386 {
387 return true;
388 }
389 template<direction dir>
390 void
391 step (dir_handler<dir>)
392 {
393 my_idx += dir;
394 }
395 template<direction dir>
396 bool
397 finished (dir_handler<dir> dirc) const
398 {
399 return dirc.is_ended (my_idx, mat.rows ());
400 } 577 }
401 }; 578 };
402 579
580 // Uses a one-sided binary search to move to the next element in this column
581 // whose row is at least rowj. The one-sided binary search guarantees
582 // O(log(rowj - currentRow)) time to find it.
583 template<typename T>
584 void
585 sparse_iterator<T>::jump_to_row (octave_idx_type rowj)
586 {
587 octave_idx_type ub = Base::mat.cidx (coli + 1);
588 octave_idx_type lo = my_idx - 1;
589 octave_idx_type hi = my_idx;
590 octave_idx_type hidiff = 1;
591 while (Base::mat.ridx (hi) < rowj)
592 {
593 lo = hi;
594 hidiff *= 2;
595 hi += hidiff;
596 if (hi >= ub)
597 {
598 hi = ub;
599 break;
600 }
601 }
602 while (hi - lo > 1)
603 {
604 octave_idx_type mid = (lo + hi) / 2;
605 if (Base::mat.ridx (mid) < rowj)
606 lo = mid;
607 else
608 hi = mid;
609 }
610 my_idx = hi;
611 adjust_col (fdirc);
612 }
613
403 #endif 614 #endif