# HG changeset patch
# User Jaroslav Hajek
# Date 1267788690 -3600
# Node ID a0b51ac0f88a3573049cc750a45909af6012e51e
# Parent  aeb5b1e4797882cb99e705190a7c29bbfe2a34f0
optimize accumdim with summation

diff -r aeb5b1e47978 -r a0b51ac0f88a liboctave/ChangeLog
--- a/liboctave/ChangeLog	Fri Mar 05 11:11:01 2010 +0100
+++ b/liboctave/ChangeLog	Fri Mar 05 12:31:30 2010 +0100
@@ -1,3 +1,8 @@
+2010-03-05  Jaroslav Hajek
+
+	* MArray.cc (MArray<T>::idx_add_nd): New method.
+	* MArray.h: Declare it.
+
 2010-03-04  Jaroslav Hajek
 
 	* lo-specfun.cc (erfcx, erfcx_impl): New functions.
diff -r aeb5b1e47978 -r a0b51ac0f88a liboctave/MArray.cc
--- a/liboctave/MArray.cc	Fri Mar 05 11:11:01 2010 +0100
+++ b/liboctave/MArray.cc	Fri Mar 05 12:31:30 2010 +0100
@@ -135,6 +135,80 @@
   idx.loop (len, _idxbinop_helper<T, xmax> (this->fortran_vec (), vals.data ()));
 }
 
+#include <algorithm>
+
+// Add the slices of VALS taken along dimension DIM (zero-based) into
+// *this: slice I of VALS is accumulated into slice IDX(I) of *this, the
+// element-wise additions being delegated to _idxadda_helper and
+// mx_inline_add2.  A negative DIM selects the first non-singleton
+// dimension of VALS.  *this is enlarged along DIM when IDX reaches past
+// its current extent; every other dimension of *this and VALS must
+// match, otherwise "accumdim: dimension mismatch" is reported through
+// current_liboctave_error_handler.
+template <class T>
+void MArray<T>::idx_add_nd (const idx_vector& idx, const MArray<T>& vals, int dim)
+{
+  int nd = std::max (this->ndims (), vals.ndims ());
+  if (dim < 0)
+    dim = vals.dims ().first_non_singleton ();
+  else if (dim >= nd)
+    nd = dim + 1;  // DIM is zero-based, so addressing it needs DIM+1 dimensions.
+
+  // Check dimensions.
+  dim_vector ddv = Array<T>::dims ().redim (nd);
+  dim_vector sdv = vals.dims ().redim (nd);
+
+  octave_idx_type ext = idx.extent (ddv (dim));
+
+  if (ext > ddv(dim))
+    {
+      ddv(dim) = ext;
+      Array<T>::resize (ddv);
+      ext = ddv(dim);
+    }
+
+  octave_idx_type l,n,u,ns;
+  get_extent_triplet (ddv, dim, l, n, u);
+  ns = sdv(dim);
+
+  // Zero the accumulated dimension on both sides so that the comparison
+  // only looks at the remaining dimensions.
+  sdv(dim) = ddv(dim) = 0;
+  if (ddv != sdv)
+    (*current_liboctave_error_handler)
+      ("accumdim: dimension mismatch");
+
+  T *dst = Array<T>::fortran_vec ();
+  const T *src = vals.data ();
+  octave_idx_type len = idx.length (ns);
+
+  if (l == 1)
+    {
+      for (octave_idx_type j = 0; j < u; j++)
+        {
+          octave_quit ();
+
+          idx.loop (len, _idxadda_helper<T> (dst + j*n, src + j*ns));
+        }
+    }
+  else
+    {
+      for (octave_idx_type j = 0; j < u; j++)
+        {
+          octave_quit ();
+          for (octave_idx_type i = 0; i < len; i++)
+            {
+              octave_idx_type k = idx(i);
+
+              mx_inline_add2 (l, dst + l*k, src + l*i);
+            }
+
+          dst += l*n;
+          src += l*ns;
+        }
+    }
+}
+
 // N-dimensional array with math ops.
 template <class T>
 void
diff -r aeb5b1e47978 -r a0b51ac0f88a liboctave/MArray.h
--- a/liboctave/MArray.h	Fri Mar 05 11:11:01 2010 +0100
+++ b/liboctave/MArray.h	Fri Mar 05 12:31:30 2010 +0100
@@ -96,6 +96,8 @@
 
   void idx_max (const idx_vector& idx, const MArray<T>& vals);
 
+  void idx_add_nd (const idx_vector& idx, const MArray<T>& vals, int dim = -1);
+
   void changesign (void);
 };
 
diff -r aeb5b1e47978 -r a0b51ac0f88a scripts/ChangeLog
--- a/scripts/ChangeLog	Fri Mar 05 11:11:01 2010 +0100
+++ b/scripts/ChangeLog	Fri Mar 05 12:31:30 2010 +0100
@@ -1,3 +1,7 @@
+2010-03-05  Jaroslav Hajek
+
+	* general/accumdim.m: Optimize the summation case.
+
 2010-03-05  Jaroslav Hajek
 
 	* general/accumdim.m: New function.
diff -r aeb5b1e47978 -r a0b51ac0f88a scripts/general/accumdim.m
--- a/scripts/general/accumdim.m	Fri Mar 05 11:11:01 2010 +0100
+++ b/scripts/general/accumdim.m	Fri Mar 05 12:31:30 2010 +0100
@@ -63,10 +63,6 @@
     fillval = 0;
   endif
 
-  if (isempty (func))
-    func = @sum;
-  endif
-
   if (! 
isvector (subs)) error ("accumdim: subs must be a subscript vector"); elseif (! isindex (subs)) # creates index cache @@ -80,16 +76,33 @@ endif endif - ## The general case. sz = size (val); if (nargin < 3) [~, dim] = max (sz != 1); # first non-singleton dim - elseif (! isindex (dim, ndims (val))) + elseif (! isindex (dim)) error ("accumdim: dim must be a valid dimension"); + elseif (dim > length (sz)) + sz(end+1:dim) = 1; endif sz(dim) = n; + if (isempty (func) || func == @sum) + ## Fast summation case. + A = __accumdim_sum__ (subs, val, dim, n); + + ## Fill in nonzero fill value + if (fillval != 0) + mask = true (n, 1); + mask(subs) = false; + subsc = {':'}(ones (1, length (sz))); + subsc{dim} = mask; + A(subsc{:}) = fillval; + endif + return + endif + + ## The general case. ns = length (subs); ## Sort indices. [subs, idx] = sort (subs(:)); diff -r aeb5b1e47978 -r a0b51ac0f88a src/ChangeLog --- a/src/ChangeLog Fri Mar 05 11:11:01 2010 +0100 +++ b/src/ChangeLog Fri Mar 05 12:31:30 2010 +0100 @@ -1,3 +1,8 @@ +2010-03-05 Jaroslav Hajek + + * data.cc (do_accumdim_sum): New helper function. + (F__accumdim_sum__): New DEFUN. + 2010-03-04 Jaroslav Hajek * ov-base.h (unary_mapper_t::umap_erfcx): New enum member. 
diff -r aeb5b1e47978 -r a0b51ac0f88a src/data.cc
--- a/src/data.cc	Fri Mar 05 11:11:01 2010 +0100
+++ b/src/data.cc	Fri Mar 05 12:31:30 2010 +0100
@@ -6439,6 +6439,85 @@
 }
 
+// Helper for __accumdim_sum__: build an array shaped like VALS except
+// that dimension DIM (zero-based) has length N, filled with T (), and let
+// idx_add_nd accumulate VALS into it according to IDX.  A negative N
+// defaults to idx.extent (0); a negative DIM selects the first
+// non-singleton dimension of VALS.  Calls error () when IDX reaches
+// beyond an explicitly given N.
 template <class NDT>
+static NDT
+do_accumdim_sum (const idx_vector& idx, const NDT& vals,
+                 int dim = -1, octave_idx_type n = -1)
+{
+  typedef typename NDT::element_type T;
+  if (n < 0)
+    n = idx.extent (0);
+  else if (idx.extent (n) > n)
+    error ("accumdim: index out of range");
+
+  dim_vector rdv = vals.dims ();
+  if (dim < 0)
+    dim = vals.dims ().first_non_singleton ();
+  else if (dim >= rdv.length ())
+    rdv.resize (dim+1, 1);
+
+  rdv(dim) = n;
+
+  NDT retval (rdv, T());
+
+  retval.idx_add_nd (idx, vals, dim);
+  return retval;
+}
+
+DEFUN (__accumdim_sum__, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Built-in Function} {} __accumdim_sum__ (@var{idx}, @var{vals}, @var{dim}, @var{n})\n\
+Undocumented internal function.\n\
+@end deftypefn")
+{
+  octave_value retval;
+  int nargin = args.length ();
+  if (nargin >= 2 && nargin <= 4 && args(0).is_numeric_type ())
+    {
+      idx_vector idx = args(0).index_vector ();
+      int dim = -1;
+      if (nargin >= 3)
+        dim = args(2).int_value () - 1;  // one-based argument -> zero-based
+
+      octave_idx_type n = -1;
+      if (nargin == 4)
+        n = args(3).idx_type_value (true);
+
+      if (! error_state)
+        {
+          octave_value vals = args(1);
+
+          // Pick the array type matching the class of VALS.
+          if (vals.is_single_type ())
+            {
+              if (vals.is_complex_type ())
+                retval = do_accumdim_sum (idx, vals.float_complex_array_value (), dim, n);
+              else
+                retval = do_accumdim_sum (idx, vals.float_array_value (), dim, n);
+            }
+          else if (vals.is_numeric_type () || vals.is_bool_type ())
+            {
+              if (vals.is_complex_type ())
+                retval = do_accumdim_sum (idx, vals.complex_array_value (), dim, n);
+              else
+                retval = do_accumdim_sum (idx, vals.array_value (), dim, n);
+            }
+          else
+            gripe_wrong_type_arg ("accumdim", vals);
+        }
+    }
+  else
+    print_usage ();
+
+  return retval;
+}
+
+template <class NDT>
 static NDT
 do_merge (const Array<bool>& mask,
           const NDT& tval, const NDT& fval)