Mercurial > octave-nkf
comparison src/DLD-FUNCTIONS/cellfun.cc @ 8994:a8d30dc1beec
cellfun optimizations
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 18 Mar 2009 12:06:46 +0100 |
parents | 193804a4f82f |
children | 97aa01a85ea4 |
comparison
equal
deleted
inserted
replaced
8993:6769599e3458 | 8994:a8d30dc1beec |
---|---|
26 #endif | 26 #endif |
27 | 27 |
28 #include <string> | 28 #include <string> |
29 #include <vector> | 29 #include <vector> |
30 #include <list> | 30 #include <list> |
31 #include <memory> | |
31 | 32 |
32 #include "lo-mappers.h" | 33 #include "lo-mappers.h" |
33 #include "oct-locbuf.h" | 34 #include "oct-locbuf.h" |
34 | 35 |
35 #include "Cell.h" | 36 #include "Cell.h" |
37 #include "defun-dld.h" | 38 #include "defun-dld.h" |
38 #include "parse.h" | 39 #include "parse.h" |
39 #include "variables.h" | 40 #include "variables.h" |
40 #include "ov-colon.h" | 41 #include "ov-colon.h" |
41 #include "unwind-prot.h" | 42 #include "unwind-prot.h" |
43 | |
44 // Rationale: | |
45 // The octave_base_value::subsasgn method carries too much overhead for | |
46 // per-element assignment strategy. | |
47 // This class will optimize the most optimistic and most likely case | |
48 // when the output really is scalar by defining a hierarchy of virtual | |
49 // collectors specialized for some scalar types. | |
50 | |
51 class scalar_col_helper | |
52 { | |
53 public: | |
54 virtual bool collect (octave_idx_type i, const octave_value& val) = 0; | |
55 virtual octave_value result (void) = 0; | |
56 virtual ~scalar_col_helper (void) { } | |
57 }; | |
58 | |
59 // The default collector represents what was previously done in the main loop. | |
60 // This reuses the existing assignment machinery via octave_value::subsasgn, | |
61 // which can perform all sorts of conversions, but is relatively slow. | |
62 | |
63 class scalar_col_helper_def : public scalar_col_helper | |
64 { | |
65 std::list<octave_value_list> idx_list; | |
66 octave_value resval; | |
67 public: | |
68 scalar_col_helper_def (const octave_value& val, const dim_vector& dims) | |
69 : idx_list (1), resval (val) | |
70 { | |
71 idx_list.front ().resize (1); | |
72 if (resval.dims () != dims) | |
73 resval.resize (dims); | |
74 } | |
75 ~scalar_col_helper_def (void) { } | |
76 | |
77 bool collect (octave_idx_type i, const octave_value& val) | |
78 { | |
79 if (val.numel () == 1) | |
80 { | |
81 idx_list.front ()(0) = static_cast<double> (i + 1); | |
82 resval = resval.subsasgn ("(", idx_list, val); | |
83 } | |
84 else | |
85 error ("cellfun: expecting all values to be scalars for UniformOutput = true"); | |
86 | |
87 return true; | |
88 } | |
89 octave_value result (void) | |
90 { | |
91 return resval; | |
92 } | |
93 }; | |
94 | |
95 template <class T> | |
96 struct scalar_query_helper { }; | |
97 | |
98 #define DEF_QUERY_HELPER(T, TEST, QUERY) \ | |
99 template <> \ | |
100 struct scalar_query_helper<T> \ | |
101 { \ | |
102 static bool has_value (const octave_value& val) \ | |
103 { return TEST; } \ | |
104 static T get_value (const octave_value& val) \ | |
105 { return QUERY; } \ | |
106 } | |
107 | |
108 DEF_QUERY_HELPER (double, val.is_real_scalar (), val.scalar_value ()); | |
109 DEF_QUERY_HELPER (Complex, val.is_complex_scalar (), val.complex_value ()); | |
110 DEF_QUERY_HELPER (float, val.is_single_type () && val.is_real_scalar (), | |
111 val.float_scalar_value ()); | |
112 DEF_QUERY_HELPER (FloatComplex, val.is_single_type () && val.is_complex_scalar (), | |
113 val.float_complex_value ()); | |
114 DEF_QUERY_HELPER (bool, val.is_bool_scalar (), val.bool_value ()); | |
115 // FIXME: More? | |
116 | |
117 // This specializes for collecting elements of a single type, by accessing | |
118 // an array directly. If the scalar is not valid, it returns false. | |
119 | |
120 template <class NDA> | |
121 class scalar_col_helper_nda : public scalar_col_helper | |
122 { | |
123 NDA arrayval; | |
124 typedef typename NDA::element_type T; | |
125 public: | |
126 scalar_col_helper_nda (const octave_value& val, const dim_vector& dims) | |
127 : arrayval (dims) | |
128 { | |
129 arrayval(0) = scalar_query_helper<T>::get_value (val); | |
130 } | |
131 ~scalar_col_helper_nda (void) { } | |
132 | |
133 bool collect (octave_idx_type i, const octave_value& val) | |
134 { | |
135 bool retval = scalar_query_helper<T>::has_value (val); | |
136 if (retval) | |
137 arrayval(i) = scalar_query_helper<T>::get_value (val); | |
138 return retval; | |
139 } | |
140 octave_value result (void) | |
141 { | |
142 return arrayval; | |
143 } | |
144 }; | |
145 | |
146 template class scalar_col_helper_nda<NDArray>; | |
147 template class scalar_col_helper_nda<FloatNDArray>; | |
148 template class scalar_col_helper_nda<ComplexNDArray>; | |
149 template class scalar_col_helper_nda<FloatComplexNDArray>; | |
150 template class scalar_col_helper_nda<boolNDArray>; | |
151 | |
152 // the virtual constructor. | |
153 scalar_col_helper * | |
154 make_col_helper (const octave_value& val, const dim_vector& dims) | |
155 { | |
156 scalar_col_helper *retval; | |
157 | |
158 if (val.is_bool_scalar ()) | |
159 retval = new scalar_col_helper_nda<boolNDArray> (val, dims); | |
160 else if (val.is_complex_scalar ()) | |
161 { | |
162 if (val.is_single_type ()) | |
163 retval = new scalar_col_helper_nda<FloatComplexNDArray> (val, dims); | |
164 else | |
165 retval = new scalar_col_helper_nda<ComplexNDArray> (val, dims); | |
166 } | |
167 else if (val.is_real_scalar ()) | |
168 { | |
169 if (val.is_single_type ()) | |
170 retval = new scalar_col_helper_nda<FloatNDArray> (val, dims); | |
171 else | |
172 retval = new scalar_col_helper_nda<NDArray> (val, dims); | |
173 } | |
174 else | |
175 retval = new scalar_col_helper_def (val, dims); | |
176 | |
177 return retval; | |
178 } | |
42 | 179 |
43 DEFUN_DLD (cellfun, args, nargout, | 180 DEFUN_DLD (cellfun, args, nargout, |
44 "-*- texinfo -*-\n\ | 181 "-*- texinfo -*-\n\ |
45 @deftypefn {Loadable Function} {} cellfun (@var{name}, @var{c})\n\ | 182 @deftypefn {Loadable Function} {} cellfun (@var{name}, @var{c})\n\ |
46 @deftypefnx {Loadable Function} {} cellfun (\"size\", @var{c}, @var{k})\n\ | 183 @deftypefnx {Loadable Function} {} cellfun (\"size\", @var{c}, @var{k})\n\ |
162 error ("cellfun: second argument must be a cell array"); | 299 error ("cellfun: second argument must be a cell array"); |
163 | 300 |
164 return retval; | 301 return retval; |
165 } | 302 } |
166 | 303 |
167 Cell f_args = args(1).cell_value (); | 304 const Cell f_args = args(1).cell_value (); |
168 | 305 |
169 octave_idx_type k = f_args.numel (); | 306 octave_idx_type k = f_args.numel (); |
170 | 307 |
171 if (name == "isempty") | 308 if (name == "isempty") |
172 { | 309 { |
201 NDArray result (f_args.dims ()); | 338 NDArray result (f_args.dims ()); |
202 for (octave_idx_type count = 0; count < k ; count++) | 339 for (octave_idx_type count = 0; count < k ; count++) |
203 result(count) = static_cast<double> (f_args.elem(count).ndims ()); | 340 result(count) = static_cast<double> (f_args.elem(count).ndims ()); |
204 retval(0) = result; | 341 retval(0) = result; |
205 } | 342 } |
206 else if (name == "prodofsize") | 343 else if (name == "prodofsize" || name == "numel") |
207 { | 344 { |
208 NDArray result (f_args.dims ()); | 345 NDArray result (f_args.dims ()); |
209 for (octave_idx_type count = 0; count < k ; count++) | 346 for (octave_idx_type count = 0; count < k ; count++) |
210 result(count) = static_cast<double> (f_args.elem(count).numel ()); | 347 result(count) = static_cast<double> (f_args.elem(count).numel ()); |
211 retval(0) = result; | 348 retval(0) = result; |
269 | 406 |
270 if (! func) | 407 if (! func) |
271 error ("unknown function"); | 408 error ("unknown function"); |
272 else | 409 else |
273 { | 410 { |
274 octave_value_list idx; | |
275 octave_value_list inputlist; | 411 octave_value_list inputlist; |
276 bool uniform_output = true; | 412 bool uniform_output = true; |
277 bool have_error_handler = false; | 413 bool have_error_handler = false; |
278 std::string err_name; | 414 std::string err_name; |
279 octave_function *error_handler = 0; | 415 octave_function *error_handler = 0; |
280 int offset = 1; | 416 int offset = 1; |
281 int i = 1; | 417 int i = 1; |
282 OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin); | 418 OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin); |
419 // This is to prevent copy-on-write. | |
420 const Cell *cinputs = inputs; | |
283 | 421 |
284 while (i < nargin) | 422 while (i < nargin) |
285 { | 423 { |
286 if (args(i).is_string()) | 424 if (args(i).is_string()) |
287 { | 425 { |
343 } | 481 } |
344 i++; | 482 i++; |
345 } | 483 } |
346 } | 484 } |
347 | 485 |
348 inputlist.resize(nargin-offset); | 486 nargin -= offset; |
487 inputlist.resize(nargin); | |
349 | 488 |
350 if (have_error_handler) | 489 if (have_error_handler) |
351 buffer_error_messages++; | 490 buffer_error_messages++; |
352 | 491 |
353 if (uniform_output) | 492 if (uniform_output) |
354 { | 493 { |
355 retval.resize(nargout); | 494 OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout); |
356 | 495 |
357 for (octave_idx_type count = 0; count < k ; count++) | 496 for (octave_idx_type count = 0; count < k ; count++) |
358 { | 497 { |
359 for (int j = 0; j < nargin-offset; j++) | 498 for (int j = 0; j < nargin; j++) |
360 inputlist(j) = inputs[j](count); | 499 inputlist(j) = cinputs[j](count); |
361 | 500 |
362 octave_value_list tmp = feval (func, inputlist, nargout); | 501 octave_value_list tmp = feval (func, inputlist, nargout); |
363 | 502 |
364 if (error_state && have_error_handler) | 503 if (error_state && have_error_handler) |
365 { | 504 { |
389 | 528 |
390 if (count == 0) | 529 if (count == 0) |
391 { | 530 { |
392 for (int j = 0; j < nargout; j++) | 531 for (int j = 0; j < nargout; j++) |
393 { | 532 { |
394 octave_value val; | 533 octave_value val = tmp(j); |
395 val = tmp(j); | 534 |
396 | 535 if (val.numel () == 1) |
397 if (error_state) | 536 retptr[j].reset (make_col_helper (val, f_args.dims ())); |
398 goto cellfun_err; | 537 else |
399 | 538 { |
400 retval(j) = val.resize(f_args.dims()); | 539 error ("cellfun: expecting all values to be scalars for UniformOutput = true"); |
540 break; | |
541 } | |
401 } | 542 } |
402 } | 543 } |
403 else | 544 else |
404 { | 545 { |
405 idx(0) = octave_value (static_cast<double>(count+1)); | |
406 for (int j = 0; j < nargout; j++) | 546 for (int j = 0; j < nargout; j++) |
407 { | 547 { |
408 // FIXME -- need an easier way to express | |
409 // this test. | |
410 octave_value val = tmp(j); | 548 octave_value val = tmp(j); |
411 | 549 |
412 if (val.ndims () == 2 | 550 if (! retptr[j]->collect (count, val)) |
413 && val.rows () == 1 && val.columns () == 1) | 551 { |
414 retval(j) = | 552 // FIXME: A more elaborate structure would allow again a virtual |
415 retval(j).subsasgn ("(", | 553 // constructor here. |
416 std::list<octave_value_list> | 554 retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), |
417 (1, idx(0)), val); | 555 f_args.dims ())); |
418 else | 556 retptr[j]->collect (count, val); |
419 error ("cellfun: expecting all values to be scalars for UniformOutput = true"); | 557 } |
420 } | 558 } |
421 } | 559 } |
422 | 560 |
423 if (error_state) | 561 if (error_state) |
424 break; | 562 break; |
425 } | 563 } |
564 | |
565 retval.resize (nargout); | |
566 for (int j = 0; j < nargout; j++) | |
567 { | |
568 if (retptr[j].get ()) | |
569 retval(j) = retptr[j]->result (); | |
570 else | |
571 retval(j) = Matrix (); | |
572 } | |
426 } | 573 } |
427 else | 574 else |
428 { | 575 { |
429 OCTAVE_LOCAL_BUFFER (Cell, results, nargout); | 576 OCTAVE_LOCAL_BUFFER (Cell, results, nargout); |
430 for (int j = 0; j < nargout; j++) | 577 for (int j = 0; j < nargout; j++) |
431 results[j].resize(f_args.dims()); | 578 results[j].resize(f_args.dims()); |
432 | 579 |
433 for (octave_idx_type count = 0; count < k ; count++) | 580 for (octave_idx_type count = 0; count < k ; count++) |
434 { | 581 { |
435 for (int j = 0; j < nargin-offset; j++) | 582 for (int j = 0; j < nargin; j++) |
436 inputlist(j) = inputs[j](count); | 583 inputlist(j) = cinputs[j](count); |
437 | 584 |
438 octave_value_list tmp = feval (func, inputlist, nargout); | 585 octave_value_list tmp = feval (func, inputlist, nargout); |
439 | 586 |
440 if (error_state && have_error_handler) | 587 if (error_state && have_error_handler) |
441 { | 588 { |