changeset 31857:3daa1cfe091d

movfun.m: Return correct class of output (bug #63802). * movfun.m: Check whether there are any elements to process based on requested dimension. If there are none, short-circuit execution and return an empty matrix with the same size as the input. Rename variable "tmp" to "fout" for clarity. Declare new variable "yclass" which is the class ouf the output determined by executing "f(x)". Initialize output "y" to be an array of zeros of class "yclass". Add BIST tests for bug #63802. * movfun.m (movfun_oncol): Add new input parameter "yclass" to function. Initialize output "y" to be an array of zeros of class "yclass".
author Rik <rik@octave.org>
date Sun, 26 Feb 2023 19:54:58 -0800
parents c80cf1588ed0
children 84805310d0d7
files scripts/signal/movfun.m
diffstat 1 files changed, 33 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/signal/movfun.m	Sat Feb 25 10:28:39 2023 -0500
+++ b/scripts/signal/movfun.m	Sun Feb 26 19:54:58 2023 -0800
@@ -204,6 +204,13 @@
   endif
 
   N = szx(dim);
+  if (N == 0)
+    ## Nothing to do.  Return immediately with empty output same shape as input.
+    ## Technically, it would be best to return the correct class, rather than
+    ## always "double", but this seems like a lot of work for little gain.
+    y = zeros (szx);
+    return;
+  endif
 
   ## Calculate slicing indices.  This call also validates WLEN input.
   [slc, C, Cpre, Cpos, win] = movslice (N, wlen);
@@ -256,8 +263,9 @@
 
   ## FIXME: Validation doesn't seem to work correctly (noted 12/16/2018).
   ## Validate that outdim makes sense
-  tmp     = fcn (zeros (length (win), 1));  # output for window
-  noutdim = length (tmp);                   # number of output dimensions
+  fout = fcn (zeros (length (win), 1, class (x)));  # output for window
+  yclass = class (fout);                    # record class of fcn output
+  noutdim = length (fout);                  # number of output dimensions
   if (! isempty (outdim))
     if (max (outdim) > noutdim)
       error ("Octave:invalid-input-arg", ...
@@ -275,11 +283,12 @@
     fcn_ = fcn;
   endif
 
+  ## Initialize output array of appropriate size and class.
+  y = zeros (N, ncols, soutdim, yclass);
   ## Apply processing to each column
   ## FIXME: Is it faster with cellfun?  Don't think so, but needs testing.
-  y = zeros (N, ncols, soutdim);
   parfor i = 1:ncols
-    y(:,i,:) = movfun_oncol (fcn_, x(:,i), wlen, bcfcn,
+    y(:,i,:) = movfun_oncol (fcn_, yclass, x(:,i), wlen, bcfcn,
                              slc, C, Cpre, Cpos, win, soutdim);
   endparfor
 
@@ -290,12 +299,12 @@
 
 endfunction
 
-function y = movfun_oncol (fcn, x, wlen, bcfcn, slcidx, C, Cpre, Cpos, win, odim)
+function y = movfun_oncol (fcn, yclass, x, wlen, bcfcn, slcidx, C, Cpre, Cpos, win, odim)
 
   N = length (Cpre) + length (C) + length (Cpos);
-  y = NA (N, odim);
+  y = zeros (N, odim, yclass);
 
-  ## Process center part
+  ## Process center of data
   try
     y(C,:) = fcn (x(slcidx));
   catch err
@@ -306,10 +315,10 @@
 
     ## Try divide and conquer approach with smaller slices of data.
     ## For loops are slow, so don't try too hard with this approach.
-    Nslices = 8;  # configurable
-    idx1 = fix (linspace (1, numel (C), Nslices));
-    idx2 = fix (linspace (1, columns (slcidx), Nslices));
-    for i = 1 : Nslices-1
+    N_SLICES = 8;  # configurable
+    idx1 = fix (linspace (1, numel (C), N_SLICES));
+    idx2 = fix (linspace (1, columns (slcidx), N_SLICES));
+    for i = 1 : N_SLICES-1
       y(C(idx1(i):idx1(i+1)),:) = fcn (x(slcidx(:, idx2(i):idx2(i+1))));
     endfor
   end_try_catch
@@ -617,13 +626,25 @@
 %!assert (movfun (@min, UNO, wlen02, "Endpoints", "same"), UNO)
 %!assert (movfun (@min, UNO, wlen20, "Endpoints", "same"), UNO)
 
-## Multidimensional output
+## Multi-dimensional output
 %!assert (size( movfun (@(x) [min(x), max(x)], (1:10).', 3)), [10 2])
 %!assert (size( movfun (@(x) [min(x), max(x)], cumsum (ones (10,5),2), 3)),
 %!        [10 5 2])
 ## outdim > dim
 %!error movfun (@(x) [min(x), max(x)], (1:10).', 3, "Outdim", 3)
 
+## Test for correct return class based on output of function. 
+%!test <*63802>
+%! x = single (1:10);
+%! y = movfun (@mean, x, 3);
+%! assert (class (y), 'single');
+%! y = movfun (@mean, uint8 (x), 3);
+%! assert (class (y), 'double');
+
+## Test calculation along empty dimension
+%!assert <*63802> (movfun (@mean, zeros (2,0,3, 'uint8'), 3, 'dim', 2),
+%!                 zeros (2,0,3, 'double'))
+
 ## Test input validation
 %!error <Invalid call> movfun ()
 %!error <Invalid call> movfun (@min)