comparison src/DLD-FUNCTIONS/cellfun.cc @ 8994:a8d30dc1beec

cellfun optimizations
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 18 Mar 2009 12:06:46 +0100
parents 193804a4f82f
children 97aa01a85ea4
comparison
equal deleted inserted replaced
8993:6769599e3458 8994:a8d30dc1beec
26 #endif 26 #endif
27 27
28 #include <string> 28 #include <string>
29 #include <vector> 29 #include <vector>
30 #include <list> 30 #include <list>
31 #include <memory>
31 32
32 #include "lo-mappers.h" 33 #include "lo-mappers.h"
33 #include "oct-locbuf.h" 34 #include "oct-locbuf.h"
34 35
35 #include "Cell.h" 36 #include "Cell.h"
37 #include "defun-dld.h" 38 #include "defun-dld.h"
38 #include "parse.h" 39 #include "parse.h"
39 #include "variables.h" 40 #include "variables.h"
40 #include "ov-colon.h" 41 #include "ov-colon.h"
41 #include "unwind-prot.h" 42 #include "unwind-prot.h"
43
44 // Rationale:
45 // The octave_base_value::subsasgn method carries too much overhead for
46 // per-element assignment strategy.
47 // This class will optimize the most optimistic and most likely case
48 // when the output really is scalar by defining a hierarchy of virtual
49 // collectors specialized for some scalar types.
50
51 class scalar_col_helper
52 {
53 public:
54 virtual bool collect (octave_idx_type i, const octave_value& val) = 0;
55 virtual octave_value result (void) = 0;
56 virtual ~scalar_col_helper (void) { }
57 };
58
59 // The default collector represents what was previously done in the main loop.
60 // This reuses the existing assignment machinery via octave_value::subsasgn,
61 // which can perform all sorts of conversions, but is relatively slow.
62
63 class scalar_col_helper_def : public scalar_col_helper
64 {
65 std::list<octave_value_list> idx_list;
66 octave_value resval;
67 public:
68 scalar_col_helper_def (const octave_value& val, const dim_vector& dims)
69 : idx_list (1), resval (val)
70 {
71 idx_list.front ().resize (1);
72 if (resval.dims () != dims)
73 resval.resize (dims);
74 }
75 ~scalar_col_helper_def (void) { }
76
77 bool collect (octave_idx_type i, const octave_value& val)
78 {
79 if (val.numel () == 1)
80 {
81 idx_list.front ()(0) = static_cast<double> (i + 1);
82 resval = resval.subsasgn ("(", idx_list, val);
83 }
84 else
85 error ("cellfun: expecting all values to be scalars for UniformOutput = true");
86
87 return true;
88 }
89 octave_value result (void)
90 {
91 return resval;
92 }
93 };
94
95 template <class T>
96 struct scalar_query_helper { };
97
98 #define DEF_QUERY_HELPER(T, TEST, QUERY) \
99 template <> \
100 struct scalar_query_helper<T> \
101 { \
102 static bool has_value (const octave_value& val) \
103 { return TEST; } \
104 static T get_value (const octave_value& val) \
105 { return QUERY; } \
106 }
107
108 DEF_QUERY_HELPER (double, val.is_real_scalar (), val.scalar_value ());
109 DEF_QUERY_HELPER (Complex, val.is_complex_scalar (), val.complex_value ());
110 DEF_QUERY_HELPER (float, val.is_single_type () && val.is_real_scalar (),
111 val.float_scalar_value ());
112 DEF_QUERY_HELPER (FloatComplex, val.is_single_type () && val.is_complex_scalar (),
113 val.float_complex_value ());
114 DEF_QUERY_HELPER (bool, val.is_bool_scalar (), val.bool_value ());
115 // FIXME: More?
116
117 // This specializes for collecting elements of a single type, by accessing
118 // an array directly. If the scalar is not valid, it returns false.
119
120 template <class NDA>
121 class scalar_col_helper_nda : public scalar_col_helper
122 {
123 NDA arrayval;
124 typedef typename NDA::element_type T;
125 public:
126 scalar_col_helper_nda (const octave_value& val, const dim_vector& dims)
127 : arrayval (dims)
128 {
129 arrayval(0) = scalar_query_helper<T>::get_value (val);
130 }
131 ~scalar_col_helper_nda (void) { }
132
133 bool collect (octave_idx_type i, const octave_value& val)
134 {
135 bool retval = scalar_query_helper<T>::has_value (val);
136 if (retval)
137 arrayval(i) = scalar_query_helper<T>::get_value (val);
138 return retval;
139 }
140 octave_value result (void)
141 {
142 return arrayval;
143 }
144 };
145
146 template class scalar_col_helper_nda<NDArray>;
147 template class scalar_col_helper_nda<FloatNDArray>;
148 template class scalar_col_helper_nda<ComplexNDArray>;
149 template class scalar_col_helper_nda<FloatComplexNDArray>;
150 template class scalar_col_helper_nda<boolNDArray>;
151
152 // the virtual constructor.
153 scalar_col_helper *
154 make_col_helper (const octave_value& val, const dim_vector& dims)
155 {
156 scalar_col_helper *retval;
157
158 if (val.is_bool_scalar ())
159 retval = new scalar_col_helper_nda<boolNDArray> (val, dims);
160 else if (val.is_complex_scalar ())
161 {
162 if (val.is_single_type ())
163 retval = new scalar_col_helper_nda<FloatComplexNDArray> (val, dims);
164 else
165 retval = new scalar_col_helper_nda<ComplexNDArray> (val, dims);
166 }
167 else if (val.is_real_scalar ())
168 {
169 if (val.is_single_type ())
170 retval = new scalar_col_helper_nda<FloatNDArray> (val, dims);
171 else
172 retval = new scalar_col_helper_nda<NDArray> (val, dims);
173 }
174 else
175 retval = new scalar_col_helper_def (val, dims);
176
177 return retval;
178 }
42 179
43 DEFUN_DLD (cellfun, args, nargout, 180 DEFUN_DLD (cellfun, args, nargout,
44 "-*- texinfo -*-\n\ 181 "-*- texinfo -*-\n\
45 @deftypefn {Loadable Function} {} cellfun (@var{name}, @var{c})\n\ 182 @deftypefn {Loadable Function} {} cellfun (@var{name}, @var{c})\n\
46 @deftypefnx {Loadable Function} {} cellfun (\"size\", @var{c}, @var{k})\n\ 183 @deftypefnx {Loadable Function} {} cellfun (\"size\", @var{c}, @var{k})\n\
162 error ("cellfun: second argument must be a cell array"); 299 error ("cellfun: second argument must be a cell array");
163 300
164 return retval; 301 return retval;
165 } 302 }
166 303
167 Cell f_args = args(1).cell_value (); 304 const Cell f_args = args(1).cell_value ();
168 305
169 octave_idx_type k = f_args.numel (); 306 octave_idx_type k = f_args.numel ();
170 307
171 if (name == "isempty") 308 if (name == "isempty")
172 { 309 {
201 NDArray result (f_args.dims ()); 338 NDArray result (f_args.dims ());
202 for (octave_idx_type count = 0; count < k ; count++) 339 for (octave_idx_type count = 0; count < k ; count++)
203 result(count) = static_cast<double> (f_args.elem(count).ndims ()); 340 result(count) = static_cast<double> (f_args.elem(count).ndims ());
204 retval(0) = result; 341 retval(0) = result;
205 } 342 }
206 else if (name == "prodofsize") 343 else if (name == "prodofsize" || name == "numel")
207 { 344 {
208 NDArray result (f_args.dims ()); 345 NDArray result (f_args.dims ());
209 for (octave_idx_type count = 0; count < k ; count++) 346 for (octave_idx_type count = 0; count < k ; count++)
210 result(count) = static_cast<double> (f_args.elem(count).numel ()); 347 result(count) = static_cast<double> (f_args.elem(count).numel ());
211 retval(0) = result; 348 retval(0) = result;
269 406
270 if (! func) 407 if (! func)
271 error ("unknown function"); 408 error ("unknown function");
272 else 409 else
273 { 410 {
274 octave_value_list idx;
275 octave_value_list inputlist; 411 octave_value_list inputlist;
276 bool uniform_output = true; 412 bool uniform_output = true;
277 bool have_error_handler = false; 413 bool have_error_handler = false;
278 std::string err_name; 414 std::string err_name;
279 octave_function *error_handler = 0; 415 octave_function *error_handler = 0;
280 int offset = 1; 416 int offset = 1;
281 int i = 1; 417 int i = 1;
282 OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin); 418 OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
419 // This is to prevent copy-on-write.
420 const Cell *cinputs = inputs;
283 421
284 while (i < nargin) 422 while (i < nargin)
285 { 423 {
286 if (args(i).is_string()) 424 if (args(i).is_string())
287 { 425 {
343 } 481 }
344 i++; 482 i++;
345 } 483 }
346 } 484 }
347 485
348 inputlist.resize(nargin-offset); 486 nargin -= offset;
487 inputlist.resize(nargin);
349 488
350 if (have_error_handler) 489 if (have_error_handler)
351 buffer_error_messages++; 490 buffer_error_messages++;
352 491
353 if (uniform_output) 492 if (uniform_output)
354 { 493 {
355 retval.resize(nargout); 494 OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout);
356 495
357 for (octave_idx_type count = 0; count < k ; count++) 496 for (octave_idx_type count = 0; count < k ; count++)
358 { 497 {
359 for (int j = 0; j < nargin-offset; j++) 498 for (int j = 0; j < nargin; j++)
360 inputlist(j) = inputs[j](count); 499 inputlist(j) = cinputs[j](count);
361 500
362 octave_value_list tmp = feval (func, inputlist, nargout); 501 octave_value_list tmp = feval (func, inputlist, nargout);
363 502
364 if (error_state && have_error_handler) 503 if (error_state && have_error_handler)
365 { 504 {
389 528
390 if (count == 0) 529 if (count == 0)
391 { 530 {
392 for (int j = 0; j < nargout; j++) 531 for (int j = 0; j < nargout; j++)
393 { 532 {
394 octave_value val; 533 octave_value val = tmp(j);
395 val = tmp(j); 534
396 535 if (val.numel () == 1)
397 if (error_state) 536 retptr[j].reset (make_col_helper (val, f_args.dims ()));
398 goto cellfun_err; 537 else
399 538 {
400 retval(j) = val.resize(f_args.dims()); 539 error ("cellfun: expecting all values to be scalars for UniformOutput = true");
540 break;
541 }
401 } 542 }
402 } 543 }
403 else 544 else
404 { 545 {
405 idx(0) = octave_value (static_cast<double>(count+1));
406 for (int j = 0; j < nargout; j++) 546 for (int j = 0; j < nargout; j++)
407 { 547 {
408 // FIXME -- need an easier way to express
409 // this test.
410 octave_value val = tmp(j); 548 octave_value val = tmp(j);
411 549
412 if (val.ndims () == 2 550 if (! retptr[j]->collect (count, val))
413 && val.rows () == 1 && val.columns () == 1) 551 {
414 retval(j) = 552 // FIXME: A more elaborate structure would allow again a virtual
415 retval(j).subsasgn ("(", 553 // constructor here.
416 std::list<octave_value_list> 554 retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (),
417 (1, idx(0)), val); 555 f_args.dims ()));
418 else 556 retptr[j]->collect (count, val);
419 error ("cellfun: expecting all values to be scalars for UniformOutput = true"); 557 }
420 } 558 }
421 } 559 }
422 560
423 if (error_state) 561 if (error_state)
424 break; 562 break;
425 } 563 }
564
565 retval.resize (nargout);
566 for (int j = 0; j < nargout; j++)
567 {
568 if (retptr[j].get ())
569 retval(j) = retptr[j]->result ();
570 else
571 retval(j) = Matrix ();
572 }
426 } 573 }
427 else 574 else
428 { 575 {
429 OCTAVE_LOCAL_BUFFER (Cell, results, nargout); 576 OCTAVE_LOCAL_BUFFER (Cell, results, nargout);
430 for (int j = 0; j < nargout; j++) 577 for (int j = 0; j < nargout; j++)
431 results[j].resize(f_args.dims()); 578 results[j].resize(f_args.dims());
432 579
433 for (octave_idx_type count = 0; count < k ; count++) 580 for (octave_idx_type count = 0; count < k ; count++)
434 { 581 {
435 for (int j = 0; j < nargin-offset; j++) 582 for (int j = 0; j < nargin; j++)
436 inputlist(j) = inputs[j](count); 583 inputlist(j) = cinputs[j](count);
437 584
438 octave_value_list tmp = feval (func, inputlist, nargout); 585 octave_value_list tmp = feval (func, inputlist, nargout);
439 586
440 if (error_state && have_error_handler) 587 if (error_state && have_error_handler)
441 { 588 {