changeset 24869:170e8625562a

mean.m: add out_type option to control data type of return value
author Carnë Draug <carandraug@octave.org>
date Tue, 13 Mar 2018 14:42:29 +0100
parents 441b27c0fd5e
children ca43264971ea
files scripts/statistics/mean.m
diffstat 1 files changed, 125 insertions(+), 53 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/statistics/mean.m	Mon Mar 12 21:06:13 2018 -0400
+++ b/scripts/statistics/mean.m	Tue Mar 13 14:42:29 2018 +0100
@@ -21,6 +21,7 @@
 ## @deftypefnx {} {} mean (@var{x}, @var{dim})
 ## @deftypefnx {} {} mean (@var{x}, @var{opt})
 ## @deftypefnx {} {} mean (@var{x}, @var{dim}, @var{opt})
+## @deftypefnx {} {} mean (@dots{}, @var{outtype})
 ## Compute the mean of the elements of the vector @var{x}.
 ##
 ## The mean is defined as
@@ -58,6 +59,23 @@
 ## Compute the harmonic mean.
 ## @end table
 ##
+## The optional argument @var{outtype} selects the data type of the
+## output value.  The following options are recognized:
+##
+## @table @asis
+## @item @qcode{"default"}
+## Output will be of class double unless @var{x} is of class single,
+## in which case the output will also be single.
+##
+## @item @qcode{"double"}
+## Output will be of class double.
+##
+## @item @qcode{"native"}
+## Output will be the same class as @var{x} unless @var{x} is of class
+## logical in which case it returns of class double.
+##
+## @end table
+##
 ## Both @var{dim} and @var{opt} are optional.  If both are supplied, either
 ## may appear first.
 ## @seealso{median, mode}
@@ -66,70 +84,96 @@
 ## Author: KH <Kurt.Hornik@wu-wien.ac.at>
 ## Description: Compute arithmetic, geometric, and harmonic mean
 
-function y = mean (x, opt1, opt2)
+function y = mean (x, varargin)
 
-  if (nargin < 1 || nargin > 3)
+  if (nargin < 1 || nargin > 4)
     print_usage ();
   endif
 
   if (! (isnumeric (x) || islogical (x)))
     error ("mean: X must be a numeric vector or matrix");
   endif
+  nd = ndims (x);
+  sz = size (x);
 
-  need_dim = false;
+  ## We support too many options...
+
+  ## If OUTTYPE is set, it must be the last option.  If DIM and
+  ## MEAN_TYPE exist, they must be the first two options
 
-  if (nargin == 1)
-    opt = "a";
-    need_dim = true;
-  elseif (nargin == 2)
-    if (ischar (opt1))
-      opt = opt1;
-      need_dim = true;
-    else
-      dim = opt1;
-      opt = "a";
+  out_type = "default";
+  if (numel (varargin))
+    maybe_out_type = tolower (varargin{end});
+    if (any (strcmpi (maybe_out_type, {"default", "double", "native"})))
+      out_type = maybe_out_type;
+      varargin(end) = [];
     endif
-  elseif (nargin == 3)
-    if (ischar (opt1))
-      opt = opt1;
-      dim = opt2;
-    elseif (ischar (opt2))
-      opt = opt2;
-      dim = opt1;
-    else
-      error ("mean: OPT must be a string");
-    endif
-  else
+  endif
+
+  scalars = cellfun (@isscalar, varargin);
+  chars = cellfun (@ischar, varargin);
+  numerics = cellfun (@isnumeric, varargin);
+
+  dim_mask = numerics & scalars;
+  mean_type_mask = chars & scalars;
+  if (! all (dim_mask | mean_type_mask))
     print_usage ();
   endif
 
-  nd = ndims (x);
-  sz = size (x);
-  if (need_dim)
-    ## Find the first non-singleton dimension.
-    (dim = find (sz > 1, 1)) || (dim = 1);
-  else
-    if (! (isscalar (dim) && dim == fix (dim) && dim > 0))
-      error ("mean: DIM must be an integer and a valid dimension");
-    endif
-  endif
+  switch (nnz (dim_mask))
+    case 0 # Find the first non-singleton dimension
+      (dim = find (sz > 1, 1)) || (dim = 1);
+    case 1
+      dim = varargin{dim_mask};
+      if (dim != fix (dim) || dim < 1)
+        error ("mean: DIM must be an integer and a valid dimension");
+      endif
+    otherwise
+      print_usage ();
+  endswitch
 
-  n = size (x, dim);
+  switch (nnz (mean_type_mask))
+    case 0
+      mean_type = "a";
+    case 1
+      mean_type = varargin{mean_type_mask};
+    otherwise
+      print_usage ();
+  endswitch
 
-  if (strcmp (opt, "a"))
-    y = sum (x, dim) / n;
-  elseif (strcmp (opt, "g"))
-    if (all (x(:) >= 0))
-      y = exp (sum (log (x), dim) ./ n);
-    else
-      error ("mean: X must not contain any negative values");
-    endif
-  elseif (strcmp (opt, "h"))
-    y = n ./ sum (1 ./ x, dim);
-  else
-    error ("mean: option '%s' not recognized", opt);
-  endif
+  ## The actual mean computation
+  n = size (x, dim);
+  switch (mean_type)
+    case "a"
+      y = sum (x, dim) / n;
+    case "g"
+      if (all (x(:) >= 0))
+        y = exp (sum (log (x), dim) ./ n);
+      else
+        error ("mean: X must not contain any negative values");
+      endif
+    case "h"
+      y = n ./ sum (1 ./ x, dim);
+    otherwise
+      error ("mean: mean type '%s' not recognized", mean_type);
+  endswitch
 
+  ## Convert output as requested
+  switch (out_type)
+    case "default"
+      ## do nothing, the operators already do the right thing
+    case "double"
+      y = double (y);
+    case "native"
+      if (islogical (x))
+        ## ignore it, return double anyway
+      else
+        y = cast (y, class (x));
+      endif
+    otherwise
+      ## this should have been filtered out during input check, but...
+      error ("mean: OUTTYPE '%s' not recognized", out_type);
+  endswitch
 endfunction
 
 
@@ -153,12 +197,40 @@
 %!assert (mean ([1 2], 3), [1 2])
 
 ## Test input validation
-%!error mean ()
-%!error mean (1, 2, 3, 4)
+%!error <Invalid call to mean.  Correct usage is> mean ()
+%!error <Invalid call to mean.  Correct usage is> mean (1, 2, 3, 4)
 %!error <X must be a numeric> mean ({1:5})
-%!error <OPT must be a string> mean (1, 2, 3)
-%!error <DIM must be an integer> mean (1, ones (2,2))
+%!error <Invalid call to mean.  Correct usage is> mean (1, 2, 3)
+%!error <Invalid call to mean.  Correct usage is> mean (1, ones (2,2))
 %!error <DIM must be an integer> mean (1, 1.5)
 %!error <DIM must be .* a valid dimension> mean (1, 0)
 %!error <X must not contain any negative values> mean ([1 -1], "g")
-%!error <option 'b' not recognized> mean (1, "b")
+%!error <mean type 'b' not recognized> mean (1, "b")
+%!error <Invalid call to mean.  Correct usage is> mean (1, 1, "foo")
+
+## Test outtype option
+%!test
+%! in = [1 2 3];
+%! out = 2;
+%! assert (mean (in, "default"), mean (in))
+%! assert (mean (in, "default"), out)
+%!
+%! in = single ([1 2 3]);
+%! out = 2;
+%! assert (mean (in, "default"), mean (in))
+%! assert (mean (in, "default"), single (out))
+%! assert (mean (in, "double"), out)
+%! assert (mean (in, "native"), single (out))
+%!
+%! in = uint8 ([1 2 3]);
+%! out = 2;
+%! assert (mean (in, "default"), mean (in))
+%! assert (mean (in, "default"), out)
+%! assert (mean (in, "double"), out)
+%! assert (mean (in, "native"), uint8 (out))
+%!
+%! in = logical ([1 0 1]);
+%! out = 2/3;
+%! assert (mean (in, "default"), mean (in))
+%! assert (mean (in, "default"), out)
+%! assert (mean (in, "native"), out) # logical ignores native option