Mercurial > octave
changeset 24869:170e8625562a
mean.m: add out_type option to control data type of return value
author | Carnë Draug <carandraug@octave.org> |
---|---|
date | Tue, 13 Mar 2018 14:42:29 +0100 |
parents | 441b27c0fd5e |
children | ca43264971ea |
files | scripts/statistics/mean.m |
diffstat | 1 files changed, 125 insertions(+), 53 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/statistics/mean.m Mon Mar 12 21:06:13 2018 -0400 +++ b/scripts/statistics/mean.m Tue Mar 13 14:42:29 2018 +0100 @@ -21,6 +21,7 @@ ## @deftypefnx {} {} mean (@var{x}, @var{dim}) ## @deftypefnx {} {} mean (@var{x}, @var{opt}) ## @deftypefnx {} {} mean (@var{x}, @var{dim}, @var{opt}) +## @deftypefnx {} {} mean (@dots{}, @var{outtype}) ## Compute the mean of the elements of the vector @var{x}. ## ## The mean is defined as @@ -58,6 +59,23 @@ ## Compute the harmonic mean. ## @end table ## +## The optional argument @var{outtype} selects the data type of the +## output value. The following options are recognized: +## +## @table @asis +## @item @qcode{"default"} +## Output will be of class double unless @var{x} is of class single, +## in which case the output will also be single. +## +## @item @qcode{"double"} +## Output will be of class double. +## +## @item @qcode{"native"} +## Output will be the same class as @var{x} unless @var{x} is of class +## logical in which case it returns of class double. +## +## @end table +## ## Both @var{dim} and @var{opt} are optional. If both are supplied, either ## may appear first. ## @seealso{median, mode} @@ -66,70 +84,96 @@ ## Author: KH <Kurt.Hornik@wu-wien.ac.at> ## Description: Compute arithmetic, geometric, and harmonic mean -function y = mean (x, opt1, opt2) +function y = mean (x, varargin) - if (nargin < 1 || nargin > 3) + if (nargin < 1 || nargin > 4) print_usage (); endif if (! (isnumeric (x) || islogical (x))) error ("mean: X must be a numeric vector or matrix"); endif + nd = ndims (x); + sz = size (x); - need_dim = false; + ## We support too many options... + + ## If OUTTYPE is set, it must be the last option. If DIM and + ## MEAN_TYPE exist, they must be the first two options - if (nargin == 1) - opt = "a"; - need_dim = true; - elseif (nargin == 2) - if (ischar (opt1)) - opt = opt1; - need_dim = true; - else - dim = opt1; - opt = "a"; + out_type = "default"; + if (numel (varargin)) + maybe_out_type = tolower (varargin{end}); + if (any (strcmpi (maybe_out_type, {"default", "double", "native"}))) + out_type = maybe_out_type; + varargin(end) = []; endif - elseif (nargin == 3) - if (ischar (opt1)) - opt = opt1; - dim = opt2; - elseif (ischar (opt2)) - opt = opt2; - dim = opt1; - else - error ("mean: OPT must be a string"); - endif - else + endif + + scalars = cellfun (@isscalar, varargin); + chars = cellfun (@ischar, varargin); + numerics = cellfun (@isnumeric, varargin); + + dim_mask = numerics & scalars; + mean_type_mask = chars & scalars; + if (! all (dim_mask | mean_type_mask)) print_usage (); endif - nd = ndims (x); - sz = size (x); - if (need_dim) - ## Find the first non-singleton dimension. - (dim = find (sz > 1, 1)) || (dim = 1); - else - if (! (isscalar (dim) && dim == fix (dim) && dim > 0)) - error ("mean: DIM must be an integer and a valid dimension"); - endif - endif + switch (nnz (dim_mask)) + case 0 # Find the first non-singleton dimension + (dim = find (sz > 1, 1)) || (dim = 1); + case 1 + dim = varargin{dim_mask}; + if (dim != fix (dim) || dim < 1) + error ("mean: DIM must be an integer and a valid dimension"); + endif + otherwise + print_usage (); + endswitch - n = size (x, dim); + switch (nnz (mean_type_mask)) + case 0 + mean_type = "a"; + case 1 + mean_type = varargin{mean_type_mask}; + otherwise + print_usage (); + endswitch - if (strcmp (opt, "a")) - y = sum (x, dim) / n; - elseif (strcmp (opt, "g")) - if (all (x(:) >= 0)) - y = exp (sum (log (x), dim) ./ n); - else - error ("mean: X must not contain any negative values"); - endif - elseif (strcmp (opt, "h")) - y = n ./ sum (1 ./ x, dim); - else - error ("mean: option '%s' not recognized", opt); - endif + ## The actual mean computation + n = size (x, dim); + switch (mean_type) + case "a" + y = sum (x, dim) / n; + case "g" + if (all (x(:) >= 0)) + y = exp (sum (log (x), dim) ./ n); + else + error ("mean: X must not contain any negative values"); + endif + case "h" + y = n ./ sum (1 ./ x, dim); + otherwise + error ("mean: mean type '%s' not recognized", mean_type); + endswitch + ## Convert output as requested + switch (out_type) + case "default" + ## do nothing, the operators already do the right thing + case "double" + y = double (y); + case "native" + if (islogical (x)) + ## ignore it, return double anyway + else + y = cast (y, class (x)); + endif + otherwise + ## this should have been filtered out during input check, but... + error ("mean: OUTTYPE '%s' not recognized", out_type); + endswitch endfunction @@ -153,12 +197,40 @@ %!assert (mean ([1 2], 3), [1 2]) ## Test input validation -%!error mean () -%!error mean (1, 2, 3, 4) +%!error <Invalid call to mean. Correct usage is> mean () +%!error <Invalid call to mean. Correct usage is> mean (1, 2, 3, 4) %!error <X must be a numeric> mean ({1:5}) -%!error <OPT must be a string> mean (1, 2, 3) -%!error <DIM must be an integer> mean (1, ones (2,2)) +%!error <Invalid call to mean. Correct usage is> mean (1, 2, 3) +%!error <Invalid call to mean. Correct usage is> mean (1, ones (2,2)) %!error <DIM must be an integer> mean (1, 1.5) %!error <DIM must be .* a valid dimension> mean (1, 0) %!error <X must not contain any negative values> mean ([1 -1], "g") -%!error <option 'b' not recognized> mean (1, "b") +%!error <mean type 'b' not recognized> mean (1, "b") +%!error <Invalid call to mean. Correct usage is> mean (1, 1, "foo") + +## Test outtype option +%!test +%! in = [1 2 3]; +%! out = 2; +%! assert (mean (in, "default"), mean (in)) +%! assert (mean (in, "default"), out) +%! +%! in = single ([1 2 3]); +%! out = 2; +%! assert (mean (in, "default"), mean (in)) +%! assert (mean (in, "default"), single (out)) +%! assert (mean (in, "double"), out) +%! assert (mean (in, "native"), single (out)) +%! +%! in = uint8 ([1 2 3]); +%! out = 2; +%! assert (mean (in, "default"), mean (in)) +%! assert (mean (in, "default"), out) +%! assert (mean (in, "double"), out) +%! assert (mean (in, "native"), uint8 (out)) +%! +%! in = logical ([1 0 1]); +%! out = 2/3; +%! assert (mean (in, "default"), mean (in)) +%! assert (mean (in, "default"), out) +%! assert (mean (in, "native"), out) # logical ignores native option