comparison scripts/statistics/mean.m @ 31623:59422a6fbd91 stable

mean.m: Accept large DIM inputs and calculate Matlab-compatible output (bug #63411) * mean.m: Accept large DIM inputs by removing any dimensions larger than the number of dimensions of the input array. Use vectorized operations, rather than a for loop, for this input validation. Pre-declare the output array before the for loop for performance. Use ipermute() after calculations to return a correctly-dimensioned result (checked with Matlab). Change BIST test results to match the correct behavior.
author Rik <rik@octave.org>
date Fri, 02 Dec 2022 21:34:15 -0800
parents c154cc05cd1f
children 597f3ee61a48
comparison
equal deleted inserted replaced
31611:535492e34f8f 31623:59422a6fbd91
175 175
176 sz = size (x); 176 sz = size (x);
177 ndims = numel (sz); 177 ndims = numel (sz);
178 misdim = [1:ndims]; 178 misdim = [1:ndims];
179 179
180 ## keep remaining dimensions 180 dim(dim > ndims) = []; # weed out dimensions larger than array
181 for i = 1:numel (dim) 181 misdim(dim) = []; # remove dims asked for leaving missing dims
182 misdim(misdim == dim(i)) = [];
183 endfor
184 182
185 switch (numel (misdim)) 183 switch (numel (misdim))
186 ## if all dimensions are given, compute x(:) 184 ## if all dimensions are given, compute x(:)
187 case 0 185 case 0
188 n = numel (x(:)); 186 n = numel (x(:));
194 y = sum (x(:), 1) ./ n; 192 y = sum (x(:), 1) ./ n;
195 193
196 ## for 1 dimension left, return column vector 194 ## for 1 dimension left, return column vector
197 case 1 195 case 1
198 x = permute (x, [misdim, dim]); 196 x = permute (x, [misdim, dim]);
197 y = zeros (size (x, 1), 1, "like", x);
199 for i = 1:size (x, 1) 198 for i = 1:size (x, 1)
200 x_vec = x(i,:,:,:,:,:,:)(:); 199 x_vec = x(i,:)(:);
201 if (omitnan) 200 if (omitnan)
202 x_vec = x_vec(! isnan (x_vec)); 201 x_vec = x_vec(! isnan (x_vec));
203 endif 202 endif
204 y(i) = sum (x_vec, 1) ./ numel (x_vec); 203 y(i) = sum (x_vec, 1) ./ numel (x_vec);
205 endfor 204 endfor
205 y = ipermute (y, [misdim, dim]);
206 206
207 ## for 2 dimensions left, return matrix 207 ## for 2 dimensions left, return matrix
208 case 2 208 case 2
209 x = permute (x, [misdim, dim]); 209 x = permute (x, [misdim, dim]);
210 y = zeros (size (x, 1), size (x, 2), "like", x);
210 for i = 1:size (x, 1) 211 for i = 1:size (x, 1)
211 for j = 1:size (x, 2) 212 for j = 1:size (x, 2)
212 x_vec = x(i,j,:,:,:,:,:)(:); 213 x_vec = x(i,j,:)(:);
213 if (omitnan) 214 if (omitnan)
214 x_vec = x_vec(! isnan (x_vec)); 215 x_vec = x_vec(! isnan (x_vec));
215 endif 216 endif
216 y(i,j) = sum (x_vec, 1) ./ numel (x_vec); 217 y(i,j) = sum (x_vec, 1) ./ numel (x_vec);
217 endfor 218 endfor
218 endfor 219 endfor
220 y = ipermute (y, [misdim, dim]);
219 221
220 ## for more than 2 dimensions left, throw error 222 ## for more than 2 dimensions left, throw error
221 otherwise 223 otherwise
222 error ("DIM must index at least N-2 dimensions of X"); 224 error ("DIM must index at least N-2 dimensions of X");
223 endswitch 225 endswitch
317 %! assert (mean ([true false NaN], 2, "omitnan", "native"), 0.5); 319 %! assert (mean ([true false NaN], 2, "omitnan", "native"), 0.5);
318 320
319 ## Test dimension indexing with vecdim in N-dimensional arrays 321 ## Test dimension indexing with vecdim in N-dimensional arrays
320 %!test 322 %!test
321 %! x = repmat ([1:20;6:25], [5 2 6 3]); 323 %! x = repmat ([1:20;6:25], [5 2 6 3]);
322 %! assert (size (mean (x, [3 2])), [10 3]); 324 %! assert (size (mean (x, [3 2])), [10 1 1 3]);
323 %! assert (size (mean (x, [1 2])), [6 3]); 325 %! assert (size (mean (x, [1 2])), [1 1 6 3]);
324 %! assert (size (mean (x, [1 2 4])), [1 6]); 326 %! assert (size (mean (x, [1 2 4])), [1 1 6]);
325 %! assert (size (mean (x, [1 4 3])), [1 40]); 327 %! assert (size (mean (x, [1 4 3])), [1 40]);
326 %! assert (size (mean (x, [1 2 3 4])), [1 1]); 328 %! assert (size (mean (x, [1 2 3 4])), [1 1]);
327 329
328 ## Test results with vecdim in N-dimensional arrays and "omitnan" 330 ## Test results with vecdim in N-dimensional arrays and "omitnan"
329 %!test 331 %!test
330 %! x = repmat ([1:20;6:25], [5 2 6 3]); 332 %! x = repmat ([1:20;6:25], [5 2 6 3]);
331 %! m = repmat ([10.5;15.5], [5,3]); 333 %! m = repmat ([10.5;15.5], [5 1 1 3]);
332 %! assert (mean (x, [3 2]), m, 4e-14); 334 %! assert (mean (x, [3 2]), m, 4e-14);
333 %! x(2,5,6,3) = NaN; 335 %! x(2,5,6,3) = NaN;
334 %! m(2,3) = NaN; 336 %! m(2,3) = NaN;
335 %! assert (mean (x, [3 2]), m, 4e-14); 337 %! assert (mean (x, [3 2]), m, 4e-14);
336 %! m(2,3) = 15.52301255230125; 338 %! m(2,3) = 15.52301255230125;
337 %! assert (mean (x, [3 2], "omitnan"), m, 4e-14); 339 %! assert (mean (x, [3 2], "omitnan"), m, 4e-14);
338 340
339 ## Test input validation 341 ## Test input validation
340 %!error <Invalid call to mean. Correct usage is> mean () 342 %!error <Invalid call> mean ()
341 %!error <Invalid call to mean. Correct usage is> mean (1, 2, 3) 343 %!error <Invalid call> mean (1, 2, 3)
342 %!error <Invalid call to mean. Correct usage is> mean (1, 2, 3, 4, 5) 344 %!error <Invalid call> mean (1, 2, 3, 4, 5)
343 %!error <Invalid call to mean. Correct usage is> mean (1, "all", 3) 345 %!error <Invalid call> mean (1, "all", 3)
344 %!error <Invalid call to mean. Correct usage is> mean (1, "b") 346 %!error <Invalid call> mean (1, "b")
345 %!error <Invalid call to mean. Correct usage is> mean (1, 1, "foo") 347 %!error <Invalid call> mean (1, 1, "foo")
346 %!error <X must be either a numeric or logical> mean ({1:5}) 348 %!error <X must be either a numeric or logical> mean ({1:5})
347 %!error <X must be either a numeric or logical> mean ("char") 349 %!error <X must be either a numeric or logical> mean ("char")
348 %!error <DIM must be a positive integer> mean (1, ones (2,2)) 350 %!error <DIM must be a positive integer> mean (1, ones (2,2))
349 %!error <DIM must be a positive integer> mean (1, 1.5) 351 %!error <DIM must be a positive integer> mean (1, 1.5)
350 %!error <DIM must be a positive integer> mean (1, -1) 352 %!error <DIM must be a positive integer> mean (1, -1)