changeset 31623:59422a6fbd91 stable

mean.m: Accept large DIM inputs and calculate Matlab-compatible output (bug #63411) * mean.m: Accept large DIM inputs by removing any dimensions larger than the dimensions of the input array. Use vectorized operations, rather than for loop, for this input validation. Pre-declare output array before for loop for performance. Use ipermute() after calculations to return correctly-dimensioned result (checked with Matlab). Change BIST tests results to match correct behavior.
author Rik <rik@octave.org>
date Fri, 02 Dec 2022 21:34:15 -0800
parents 535492e34f8f
children 1e270beb6982 ea0b06534a37
files scripts/statistics/mean.m
diffstat 1 files changed, 18 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/statistics/mean.m	Fri Dec 02 19:33:07 2022 -0800
+++ b/scripts/statistics/mean.m	Fri Dec 02 21:34:15 2022 -0800
@@ -177,10 +177,8 @@
       ndims = numel (sz);
       misdim = [1:ndims];
 
-      ## keep remaining dimensions
-      for i = 1:numel (dim)
-        misdim(misdim == dim(i)) = [];
-      endfor
+      dim(dim > ndims) = [];  # weed out dimensions larger than array
+      misdim(dim) = [];       # remove dims asked for leaving missing dims
 
       switch (numel (misdim))
         ## if all dimensions are given, compute x(:)
@@ -196,26 +194,30 @@
         ## for 1 dimension left, return column vector
         case 1
           x = permute (x, [misdim, dim]);
+          y = zeros (size (x, 1), 1, "like", x);
           for i = 1:size (x, 1)
-            x_vec = x(i,:,:,:,:,:,:)(:);
+            x_vec = x(i,:)(:);
             if (omitnan)
               x_vec = x_vec(! isnan (x_vec));
             endif
             y(i) = sum (x_vec, 1) ./ numel (x_vec);
           endfor
+          y = ipermute (y, [misdim, dim]);
 
         ## for 2 dimensions left, return matrix
         case 2
           x = permute (x, [misdim, dim]);
+          y = zeros (size (x, 1), size (x, 2), "like", x);
           for i = 1:size (x, 1)
             for j = 1:size (x, 2)
-              x_vec = x(i,j,:,:,:,:,:)(:);
+              x_vec = x(i,j,:)(:);
               if (omitnan)
                 x_vec = x_vec(! isnan (x_vec));
               endif
               y(i,j) = sum (x_vec, 1) ./ numel (x_vec);
             endfor
           endfor
+          y = ipermute (y, [misdim, dim]);
 
         ## for more than 2 dimensions left, throw error
         otherwise
@@ -319,16 +321,16 @@
 ## Test dimension indexing with vecdim in N-dimensional arrays
 %!test
 %! x = repmat ([1:20;6:25], [5 2 6 3]);
-%! assert (size (mean (x, [3 2])), [10 3]);
-%! assert (size (mean (x, [1 2])), [6 3]);
-%! assert (size (mean (x, [1 2 4])), [1 6]);
+%! assert (size (mean (x, [3 2])), [10 1 1 3]);
+%! assert (size (mean (x, [1 2])), [1 1 6 3]);
+%! assert (size (mean (x, [1 2 4])), [1 1 6]);
 %! assert (size (mean (x, [1 4 3])), [1 40]);
 %! assert (size (mean (x, [1 2 3 4])), [1 1]);
 
 ## Test results with vecdim in N-dimensional arrays and "omitnan"
 %!test
 %! x = repmat ([1:20;6:25], [5 2 6 3]);
-%! m = repmat ([10.5;15.5], [5,3]);
+%! m = repmat ([10.5;15.5], [5 1 1 3]);
 %! assert (mean (x, [3 2]), m, 4e-14);
 %! x(2,5,6,3) = NaN;
 %! m(2,3) = NaN;
@@ -337,12 +339,12 @@
 %! assert (mean (x, [3 2], "omitnan"), m, 4e-14);
 
 ## Test input validation
-%!error <Invalid call to mean.  Correct usage is> mean ()
-%!error <Invalid call to mean.  Correct usage is> mean (1, 2, 3)
-%!error <Invalid call to mean.  Correct usage is> mean (1, 2, 3, 4, 5)
-%!error <Invalid call to mean.  Correct usage is> mean (1, "all", 3)
-%!error <Invalid call to mean.  Correct usage is> mean (1, "b")
-%!error <Invalid call to mean.  Correct usage is> mean (1, 1, "foo")
+%!error <Invalid call> mean ()
+%!error <Invalid call> mean (1, 2, 3)
+%!error <Invalid call> mean (1, 2, 3, 4, 5)
+%!error <Invalid call> mean (1, "all", 3)
+%!error <Invalid call> mean (1, "b")
+%!error <Invalid call> mean (1, 1, "foo")
 %!error <X must be either a numeric or logical> mean ({1:5})
 %!error <X must be either a numeric or logical> mean ("char")
 %!error <DIM must be a positive integer> mean (1, ones (2,2))