# HG changeset patch # User Jaroslav Hajek # Date 1237374406 -3600 # Node ID a8d30dc1beecd363afa8c8654a76f00f10b2b7fa # Parent 6769599e345841d77b8b0123f17c856b471ad76d cellfun optimizations diff -r 6769599e3458 -r a8d30dc1beec src/ChangeLog --- a/src/ChangeLog Tue Mar 17 22:21:54 2009 +0100 +++ b/src/ChangeLog Wed Mar 18 12:06:46 2009 +0100 @@ -1,3 +1,12 @@ +2009-03-18 Jaroslav Hajek + + * DLD-FUNCTIONS/cellfun.cc + (scalar_col_helper, scalar_col_helper_def, scalar_col_helper_nda, + scalar_query_helper): New classes. + (make_col_helper): New function. + (Fcellfun): Use col_helper classes for UniformOutput. Avoid copying by + using const variables. + 2009-03-17 Jaroslav Hajek * OPERATORS/op-b-bm.cc: Add missing INSTALL_BINOPs. diff -r 6769599e3458 -r a8d30dc1beec src/DLD-FUNCTIONS/cellfun.cc --- a/src/DLD-FUNCTIONS/cellfun.cc Tue Mar 17 22:21:54 2009 +0100 +++ b/src/DLD-FUNCTIONS/cellfun.cc Wed Mar 18 12:06:46 2009 +0100 @@ -28,6 +28,7 @@ #include #include #include +#include #include "lo-mappers.h" #include "oct-locbuf.h" @@ -40,6 +41,142 @@ #include "ov-colon.h" #include "unwind-prot.h" +// Rationale: +// The octave_base_value::subsasgn method carries too much overhead for +// per-element assignment strategy. +// This class will optimize the most optimistic and most likely case +// when the output really is scalar by defining a hierarchy of virtual +// collectors specialized for some scalar types. + +class scalar_col_helper +{ +public: + virtual bool collect (octave_idx_type i, const octave_value& val) = 0; + virtual octave_value result (void) = 0; + virtual ~scalar_col_helper (void) { } +}; + +// The default collector represents what was previously done in the main loop. +// This reuses the existing assignment machinery via octave_value::subsasgn, +// which can perform all sorts of conversions, but is relatively slow. + +class scalar_col_helper_def : public scalar_col_helper +{ + std::list idx_list; + octave_value resval; +public: + scalar_col_helper_def (const octave_value& val, const dim_vector& dims) + : idx_list (1), resval (val) + { + idx_list.front ().resize (1); + if (resval.dims () != dims) + resval.resize (dims); + } + ~scalar_col_helper_def (void) { } + + bool collect (octave_idx_type i, const octave_value& val) + { + if (val.numel () == 1) + { + idx_list.front ()(0) = static_cast (i + 1); + resval = resval.subsasgn ("(", idx_list, val); + } + else + error ("cellfun: expecting all values to be scalars for UniformOutput = true"); + + return true; + } + octave_value result (void) + { + return resval; + } +}; + +template +struct scalar_query_helper { }; + +#define DEF_QUERY_HELPER(T, TEST, QUERY) \ +template <> \ +struct scalar_query_helper \ +{ \ + static bool has_value (const octave_value& val) \ + { return TEST; } \ + static T get_value (const octave_value& val) \ + { return QUERY; } \ +} + +DEF_QUERY_HELPER (double, val.is_real_scalar (), val.scalar_value ()); +DEF_QUERY_HELPER (Complex, val.is_complex_scalar (), val.complex_value ()); +DEF_QUERY_HELPER (float, val.is_single_type () && val.is_real_scalar (), + val.float_scalar_value ()); +DEF_QUERY_HELPER (FloatComplex, val.is_single_type () && val.is_complex_scalar (), + val.float_complex_value ()); +DEF_QUERY_HELPER (bool, val.is_bool_scalar (), val.bool_value ()); +// FIXME: More? + +// This specializes for collecting elements of a single type, by accessing +// an array directly. If the scalar is not valid, it returns false. + +template +class scalar_col_helper_nda : public scalar_col_helper +{ + NDA arrayval; + typedef typename NDA::element_type T; +public: + scalar_col_helper_nda (const octave_value& val, const dim_vector& dims) + : arrayval (dims) + { + arrayval(0) = scalar_query_helper::get_value (val); + } + ~scalar_col_helper_nda (void) { } + + bool collect (octave_idx_type i, const octave_value& val) + { + bool retval = scalar_query_helper::has_value (val); + if (retval) + arrayval(i) = scalar_query_helper::get_value (val); + return retval; + } + octave_value result (void) + { + return arrayval; + } +}; + +template class scalar_col_helper_nda; +template class scalar_col_helper_nda; +template class scalar_col_helper_nda; +template class scalar_col_helper_nda; +template class scalar_col_helper_nda; + +// the virtual constructor. +scalar_col_helper * +make_col_helper (const octave_value& val, const dim_vector& dims) +{ + scalar_col_helper *retval; + + if (val.is_bool_scalar ()) + retval = new scalar_col_helper_nda (val, dims); + else if (val.is_complex_scalar ()) + { + if (val.is_single_type ()) + retval = new scalar_col_helper_nda (val, dims); + else + retval = new scalar_col_helper_nda (val, dims); + } + else if (val.is_real_scalar ()) + { + if (val.is_single_type ()) + retval = new scalar_col_helper_nda (val, dims); + else + retval = new scalar_col_helper_nda (val, dims); + } + else + retval = new scalar_col_helper_def (val, dims); + + return retval; +} + DEFUN_DLD (cellfun, args, nargout, "-*- texinfo -*-\n\ @deftypefn {Loadable Function} {} cellfun (@var{name}, @var{c})\n\ @@ -164,7 +301,7 @@ return retval; } - Cell f_args = args(1).cell_value (); + const Cell f_args = args(1).cell_value (); octave_idx_type k = f_args.numel (); @@ -203,7 +340,7 @@ result(count) = static_cast (f_args.elem(count).ndims ()); retval(0) = result; } - else if (name == "prodofsize") + else if (name == "prodofsize" || name == "numel") { NDArray result (f_args.dims ()); for (octave_idx_type count = 0; count < k ; count++) @@ -271,7 +408,6 @@ error ("unknown function"); else { - octave_value_list idx; octave_value_list inputlist; bool uniform_output = true; bool have_error_handler = false; @@ -280,6 +416,8 @@ int offset = 1; int i = 1; OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin); + // This is to prevent copy-on-write. + const Cell *cinputs = inputs; while (i < nargin) { @@ -345,19 +483,20 @@ } } - inputlist.resize(nargin-offset); + nargin -= offset; + inputlist.resize(nargin); if (have_error_handler) buffer_error_messages++; if (uniform_output) { - retval.resize(nargout); + OCTAVE_LOCAL_BUFFER (std::auto_ptr, retptr, nargout); for (octave_idx_type count = 0; count < k ; count++) { - for (int j = 0; j < nargin-offset; j++) - inputlist(j) = inputs[j](count); + for (int j = 0; j < nargin; j++) + inputlist(j) = cinputs[j](count); octave_value_list tmp = feval (func, inputlist, nargout); @@ -391,38 +530,46 @@ { for (int j = 0; j < nargout; j++) { - octave_value val; - val = tmp(j); + octave_value val = tmp(j); - if (error_state) - goto cellfun_err; - - retval(j) = val.resize(f_args.dims()); + if (val.numel () == 1) + retptr[j].reset (make_col_helper (val, f_args.dims ())); + else + { + error ("cellfun: expecting all values to be scalars for UniformOutput = true"); + break; + } } } else { - idx(0) = octave_value (static_cast(count+1)); for (int j = 0; j < nargout; j++) { - // FIXME -- need an easier way to express - // this test. octave_value val = tmp(j); - if (val.ndims () == 2 - && val.rows () == 1 && val.columns () == 1) - retval(j) = - retval(j).subsasgn ("(", - std::list - (1, idx(0)), val); - else - error ("cellfun: expecting all values to be scalars for UniformOutput = true"); - } + if (! retptr[j]->collect (count, val)) + { + // FIXME: A more elaborate structure would allow again a virtual + // constructor here. + retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), + f_args.dims ())); + retptr[j]->collect (count, val); + } + } } if (error_state) break; } + + retval.resize (nargout); + for (int j = 0; j < nargout; j++) + { + if (retptr[j].get ()) + retval(j) = retptr[j]->result (); + else + retval(j) = Matrix (); + } } else { @@ -432,8 +579,8 @@ for (octave_idx_type count = 0; count < k ; count++) { - for (int j = 0; j < nargin-offset; j++) - inputlist(j) = inputs[j](count); + for (int j = 0; j < nargin; j++) + inputlist(j) = cinputs[j](count); octave_value_list tmp = feval (func, inputlist, nargout); diff -r 6769599e3458 -r a8d30dc1beec src/ov-base-mat.h --- a/src/ov-base-mat.h Tue Mar 17 22:21:54 2009 +0100 +++ b/src/ov-base-mat.h Wed Mar 18 12:06:46 2009 +0100 @@ -98,6 +98,8 @@ dim_vector dims (void) const { return matrix.dims (); } + octave_idx_type numel (void) const { return matrix.numel (); } + octave_idx_type nnz (void) const { return matrix.nnz (); } octave_value reshape (const dim_vector& new_dims) const