Mercurial > octave
view scripts/statistics/mean.m @ 31623:59422a6fbd91 stable
mean.m: Accept large DIM inputs and calculate Matlab-compatible output (bug #63411)
* mean.m: Accept large DIM inputs by removing any dimensions larger than the
dimensions of the input array. Use vectorized operations, rather than for
loop, for this input validation. Pre-declare output array before for loop
for performance. Use ipermute() after calculations to return
correctly-dimensioned result (checked with Matlab). Change BIST tests results
to match correct behavior.
author | Rik <rik@octave.org> |
---|---|
date | Fri, 02 Dec 2022 21:34:15 -0800 |
parents | c154cc05cd1f |
children | 597f3ee61a48 |
line wrap: on
line source
######################################################################## ## ## Copyright (C) 1995-2022 The Octave Project Developers ## ## See the file COPYRIGHT.md in the top-level directory of this ## distribution or <https://octave.org/copyright/>. ## ## This file is part of Octave. ## ## Octave is free software: you can redistribute it and/or modify it ## under the terms of the GNU General Public License as published by ## the Free Software Foundation, either version 3 of the License, or ## (at your option) any later version. ## ## Octave is distributed in the hope that it will be useful, but ## WITHOUT ANY WARRANTY; without even the implied warranty of ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the ## GNU General Public License for more details. ## ## You should have received a copy of the GNU General Public License ## along with Octave; see the file COPYING. If not, see ## <https://www.gnu.org/licenses/>. ## ######################################################################## ## -*- texinfo -*- ## @deftypefn {} {@var{y} =} mean (@var{x}) ## @deftypefnx {} {@var{y} =} mean (@var{x}, 'all') ## @deftypefnx {} {@var{y} =} mean (@var{x}, @var{dim}) ## @deftypefnx {} {@var{y} =} mean (@dots{}, '@var{outtype}') ## @deftypefnx {} {@var{y} =} mean (@dots{}, '@var{nanflag}') ## Compute the mean of the elements of @var{x}. ## ## @itemize ## @item ## If @var{x} is a vector, then @code{mean (@var{x})} returns the ## mean of the elements in @var{x} defined as ## @tex ## $$ {\rm mean}(x) = \bar{x} = {1\over N} \sum_{i=1}^N x_i $$ ## where $N$ is the number of elements of @var{x}. ## ## @end tex ## @ifnottex ## ## @example ## mean (@var{x}) = SUM_i @var{x}(i) / N ## @end example ## ## @noindent ## where @math{N} is the number of elements in the @var{x} vector. ## ## @end ifnottex ## ## @item ## If @var{x} is a matrix, then @code{mean} returns a row vector with the mean ## of each column in @var{x}. ## ## @item ## If @var{x} is a multi-dimensional array, then @code{mean} operates along the ## first non-singleton dimension of @var{x}. ## @end itemize ## ## The optional input @var{dim} forces @code{mean} to operate over the ## specified dimension(s). @var{dim} can either be a scalar dimension or a ## vector of non-repeating dimensions. Dimensions must be positive integers, ## and the mean is calculated over the array slice defined by @var{dim}. ## ## Specifying dimension @qcode{"all"} will force @code{mean} to operate on all ## elements of @var{x}, and is equivalent to @code{mean (@var{x}(:))}. ## ## The optional input @var{outtype} specifies the data type that is returned. ## Valid values are: ## ## @table @asis ## @item @qcode{'default'} : Output is of type double, unless the input is ## single in which case the output is of type single. ## ## @item @qcode{'double'} : Output is of type double. ## ## @item @qcode{'native'} : Output is of the same type as the input ## (@code{class (@var{x})}), unless the input is logical in which case the ## output is of type double. ## ## @end table ## ## The optional input @var{nanflag} specifies whether to include/exclude NaN ## values from the calculation. By default, NaN values are included in the ## calculation (@var{nanflag} has the value @qcode{'includenan'}). To exclude ## NaN values, set the value of @var{nanflag} to @qcode{'omitnan'}. ## ## @seealso{median, mode} ## @end deftypefn function y = mean (x, varargin) if (nargin < 1 || nargin > 4 || ! all (cellfun (@ischar, varargin(2:end)))) print_usage (); endif ## Check all char arguments. all_flag = false; omitnan = false; outtype = "default"; for i = 1:numel (varargin) if (ischar (varargin{i})) switch (varargin{i}) case "all" all_flag = true; case "includenan" omitnan = false; case "omitnan" omitnan = true; case {"default", "double", "native"} outtype = varargin{i}; otherwise print_usage (); endswitch endif endfor varargin(cellfun (@ischar, varargin)) = []; if (((numel (varargin) == 1) && ! (isnumeric (varargin{1}))) ... || (numel (varargin) > 1)) print_usage (); endif if (! (isnumeric (x) || islogical (x))) error ("mean: X must be either a numeric or logical vector or matrix"); endif if (numel (varargin) == 0) ## Single numeric input argument, no dimensions given. if (all_flag) n = numel (x(:)); if (omitnan) idx = isnan (x); n -= sum (idx(:)); x(idx) = 0; endif y = sum (x(:), 1) ./ n; else sz = size (x); ## Find the first non-singleton dimension. (dim = find (sz != 1, 1)) || (dim = 1); n = size (x, dim); if (omitnan) idx = isnan (x); n = sum (! idx, dim); x(idx) = 0; endif y = sum (x, dim) ./ n; endif else ## Two numeric input arguments, dimensions given. Note scalar is vector! dim = varargin{1}; if (! (isvector (dim) && all (dim > 0) && all (rem (dim, 1) == 0))) error ("mean: DIM must be a positive integer scalar or vector"); endif if (isscalar (dim)) n = size (x, dim); if (omitnan) idx = isnan (x); n = sum (! idx, dim); x(idx) = 0; endif y = sum (x, dim) ./ n; else sz = size (x); ndims = numel (sz); misdim = [1:ndims]; dim(dim > ndims) = []; # weed out dimensions larger than array misdim(dim) = []; # remove dims asked for leaving missing dims switch (numel (misdim)) ## if all dimensions are given, compute x(:) case 0 n = numel (x(:)); if (omitnan) idx = isnan (x); n -= sum (idx(:)); x(idx) = 0; endif y = sum (x(:), 1) ./ n; ## for 1 dimension left, return column vector case 1 x = permute (x, [misdim, dim]); y = zeros (size (x, 1), 1, "like", x); for i = 1:size (x, 1) x_vec = x(i,:)(:); if (omitnan) x_vec = x_vec(! isnan (x_vec)); endif y(i) = sum (x_vec, 1) ./ numel (x_vec); endfor y = ipermute (y, [misdim, dim]); ## for 2 dimensions left, return matrix case 2 x = permute (x, [misdim, dim]); y = zeros (size (x, 1), size (x, 2), "like", x); for i = 1:size (x, 1) for j = 1:size (x, 2) x_vec = x(i,j,:)(:); if (omitnan) x_vec = x_vec(! isnan (x_vec)); endif y(i,j) = sum (x_vec, 1) ./ numel (x_vec); endfor endfor y = ipermute (y, [misdim, dim]); ## for more than 2 dimensions left, throw error otherwise error ("DIM must index at least N-2 dimensions of X"); endswitch endif endif ## Convert output if requested switch (outtype) case "default" ## do nothing, the operators already do the right thing. case "double" y = double (y); case "native" if (! islogical (x)) y = cast (y, class (x)); endif otherwise ## FIXME: This is unreachable code. Valid values already ## checked in input validation. error ("mean: OUTTYPE '%s' not recognized", outtype); endswitch endfunction %!test %! x = -10:10; %! y = x'; %! z = [y, y+10]; %! assert (mean (x), 0); %! assert (mean (y), 0); %! assert (mean (z), [0, 10]); %!assert (mean (magic (3), 1), [5, 5, 5]) %!assert (mean (magic (3), 2), [5; 5; 5]) %!assert (mean (logical ([1 0 1 1])), 0.75) %!assert (mean (single ([1 0 1 1])), single (0.75)) %!assert (mean ([1 2], 3), [1 2]) ## Test outtype option %!test %! in = [1 2 3]; %! out = 2; %! assert (mean (in, "default"), mean (in)); %! assert (mean (in, "default"), out); %! %! in = single ([1 2 3]); %! out = 2; %! assert (mean (in, "default"), mean (in)); %! assert (mean (in, "default"), single (out)); %! assert (mean (in, "double"), out); %! assert (mean (in, "native"), single (out)); %! %! in = uint8 ([1 2 3]); %! out = 2; %! assert (mean (in, "default"), mean (in)); %! assert (mean (in, "default"), out); %! assert (mean (in, "double"), out); %! assert (mean (in, "native"), uint8 (out)); %! %! in = logical ([1 0 1]); %! out = 2/3; %! assert (mean (in, "default"), mean (in)); %! assert (mean (in, "default"), out); %! assert (mean (in, "native"), out); # logical ignores native option ## Test single input and optional arguments "all", DIM, "omitnan") %!test %! x = [-10:10]; %! y = [x;x+5;x-5]; %! assert (mean (x), 0); %! assert (mean (y, 2), [0, 5, -5]'); %! assert (mean (y, "all"), 0); %! y(2,4) = NaN; %! assert (mean (y', "omitnan"), [0 5.35 -5]); %! z = y + 20; %! assert (mean (z, "all"), NaN); %! m = [20 NaN 15]; %! assert (mean (z'), m); %! assert (mean (z', "includenan"), m); %! m = [20 25.35 15]; %! assert (mean (z', "omitnan"), m); %! assert (mean (z, 2, "omitnan"), m'); %! assert (mean (z, 2, "native", "omitnan"), m'); %! assert (mean (z, 2, "omitnan", "native"), m'); ## Test boolean input %!test %! assert (mean (true, "all"), 1); %! assert (mean (false), 0); %! assert (mean ([true false true]), 2/3, 4e-14); %! assert (mean ([true false true], 1), [1 0 1]); %! assert (mean ([true false NaN], 1), [1 0 NaN]); %! assert (mean ([true false NaN], 2), NaN); %! assert (mean ([true false NaN], 2, "omitnan"), 0.5); %! assert (mean ([true false NaN], 2, "omitnan", "native"), 0.5); ## Test dimension indexing with vecdim in N-dimensional arrays %!test %! x = repmat ([1:20;6:25], [5 2 6 3]); %! assert (size (mean (x, [3 2])), [10 1 1 3]); %! assert (size (mean (x, [1 2])), [1 1 6 3]); %! assert (size (mean (x, [1 2 4])), [1 1 6]); %! assert (size (mean (x, [1 4 3])), [1 40]); %! assert (size (mean (x, [1 2 3 4])), [1 1]); ## Test results with vecdim in N-dimensional arrays and "omitnan" %!test %! x = repmat ([1:20;6:25], [5 2 6 3]); %! m = repmat ([10.5;15.5], [5 1 1 3]); %! assert (mean (x, [3 2]), m, 4e-14); %! x(2,5,6,3) = NaN; %! m(2,3) = NaN; %! assert (mean (x, [3 2]), m, 4e-14); %! m(2,3) = 15.52301255230125; %! assert (mean (x, [3 2], "omitnan"), m, 4e-14); ## Test input validation %!error <Invalid call> mean () %!error <Invalid call> mean (1, 2, 3) %!error <Invalid call> mean (1, 2, 3, 4, 5) %!error <Invalid call> mean (1, "all", 3) %!error <Invalid call> mean (1, "b") %!error <Invalid call> mean (1, 1, "foo") %!error <X must be either a numeric or logical> mean ({1:5}) %!error <X must be either a numeric or logical> mean ("char") %!error <DIM must be a positive integer> mean (1, ones (2,2)) %!error <DIM must be a positive integer> mean (1, 1.5) %!error <DIM must be a positive integer> mean (1, -1) %!error <DIM must be a positive integer> mean (1, -1.5) %!error <DIM must be a positive integer> mean (1, 0) %!error <DIM must be a positive integer> mean (1, NaN) %!error <DIM must be a positive integer> mean (1, Inf) %!error <DIM must index at least N-2 dimensions of X> %! mean (repmat ([1:20;6:25], [5 2 6 3 5]), [1 2])