Mercurial > octave
comparison scripts/statistics/mean.m @ 31623:59422a6fbd91 stable
mean.m: Accept large DIM inputs and calculate Matlab-compatible output (bug #63411)
* mean.m: Accept large DIM inputs by removing any entries of DIM larger than the
number of dimensions of the input array. Use vectorized operations, rather than
a for loop, for this input validation. Pre-declare output array before for loop
for performance. Use ipermute() after calculations to return
correctly-dimensioned result (checked with Matlab). Change BIST test results
to match correct behavior.
author | Rik <rik@octave.org> |
---|---|
date | Fri, 02 Dec 2022 21:34:15 -0800 |
parents | c154cc05cd1f |
children | 597f3ee61a48 |
comparison
equal
deleted
inserted
replaced
31611:535492e34f8f | 31623:59422a6fbd91 |
---|---|
175 | 175 |
176 sz = size (x); | 176 sz = size (x); |
177 ndims = numel (sz); | 177 ndims = numel (sz); |
178 misdim = [1:ndims]; | 178 misdim = [1:ndims]; |
179 | 179 |
180 ## keep remaining dimensions | 180 dim(dim > ndims) = []; # weed out dimensions larger than array |
181 for i = 1:numel (dim) | 181 misdim(dim) = []; # remove dims asked for leaving missing dims |
182 misdim(misdim == dim(i)) = []; | |
183 endfor | |
184 | 182 |
185 switch (numel (misdim)) | 183 switch (numel (misdim)) |
186 ## if all dimensions are given, compute x(:) | 184 ## if all dimensions are given, compute x(:) |
187 case 0 | 185 case 0 |
188 n = numel (x(:)); | 186 n = numel (x(:)); |
194 y = sum (x(:), 1) ./ n; | 192 y = sum (x(:), 1) ./ n; |
195 | 193 |
196 ## for 1 dimension left, return column vector | 194 ## for 1 dimension left, return column vector |
197 case 1 | 195 case 1 |
198 x = permute (x, [misdim, dim]); | 196 x = permute (x, [misdim, dim]); |
197 y = zeros (size (x, 1), 1, "like", x); | |
199 for i = 1:size (x, 1) | 198 for i = 1:size (x, 1) |
200 x_vec = x(i,:,:,:,:,:,:)(:); | 199 x_vec = x(i,:)(:); |
201 if (omitnan) | 200 if (omitnan) |
202 x_vec = x_vec(! isnan (x_vec)); | 201 x_vec = x_vec(! isnan (x_vec)); |
203 endif | 202 endif |
204 y(i) = sum (x_vec, 1) ./ numel (x_vec); | 203 y(i) = sum (x_vec, 1) ./ numel (x_vec); |
205 endfor | 204 endfor |
205 y = ipermute (y, [misdim, dim]); | |
206 | 206 |
207 ## for 2 dimensions left, return matrix | 207 ## for 2 dimensions left, return matrix |
208 case 2 | 208 case 2 |
209 x = permute (x, [misdim, dim]); | 209 x = permute (x, [misdim, dim]); |
210 y = zeros (size (x, 1), size (x, 2), "like", x); | |
210 for i = 1:size (x, 1) | 211 for i = 1:size (x, 1) |
211 for j = 1:size (x, 2) | 212 for j = 1:size (x, 2) |
212 x_vec = x(i,j,:,:,:,:,:)(:); | 213 x_vec = x(i,j,:)(:); |
213 if (omitnan) | 214 if (omitnan) |
214 x_vec = x_vec(! isnan (x_vec)); | 215 x_vec = x_vec(! isnan (x_vec)); |
215 endif | 216 endif |
216 y(i,j) = sum (x_vec, 1) ./ numel (x_vec); | 217 y(i,j) = sum (x_vec, 1) ./ numel (x_vec); |
217 endfor | 218 endfor |
218 endfor | 219 endfor |
220 y = ipermute (y, [misdim, dim]); | |
219 | 221 |
220 ## for more than 2 dimensions left, throw error | 222 ## for more than 2 dimensions left, throw error |
221 otherwise | 223 otherwise |
222 error ("DIM must index at least N-2 dimensions of X"); | 224 error ("DIM must index at least N-2 dimensions of X"); |
223 endswitch | 225 endswitch |
317 %! assert (mean ([true false NaN], 2, "omitnan", "native"), 0.5); | 319 %! assert (mean ([true false NaN], 2, "omitnan", "native"), 0.5); |
318 | 320 |
319 ## Test dimension indexing with vecdim in N-dimensional arrays | 321 ## Test dimension indexing with vecdim in N-dimensional arrays |
320 %!test | 322 %!test |
321 %! x = repmat ([1:20;6:25], [5 2 6 3]); | 323 %! x = repmat ([1:20;6:25], [5 2 6 3]); |
322 %! assert (size (mean (x, [3 2])), [10 3]); | 324 %! assert (size (mean (x, [3 2])), [10 1 1 3]); |
323 %! assert (size (mean (x, [1 2])), [6 3]); | 325 %! assert (size (mean (x, [1 2])), [1 1 6 3]); |
324 %! assert (size (mean (x, [1 2 4])), [1 6]); | 326 %! assert (size (mean (x, [1 2 4])), [1 1 6]); |
325 %! assert (size (mean (x, [1 4 3])), [1 40]); | 327 %! assert (size (mean (x, [1 4 3])), [1 40]); |
326 %! assert (size (mean (x, [1 2 3 4])), [1 1]); | 328 %! assert (size (mean (x, [1 2 3 4])), [1 1]); |
327 | 329 |
328 ## Test results with vecdim in N-dimensional arrays and "omitnan" | 330 ## Test results with vecdim in N-dimensional arrays and "omitnan" |
329 %!test | 331 %!test |
330 %! x = repmat ([1:20;6:25], [5 2 6 3]); | 332 %! x = repmat ([1:20;6:25], [5 2 6 3]); |
331 %! m = repmat ([10.5;15.5], [5,3]); | 333 %! m = repmat ([10.5;15.5], [5 1 1 3]); |
332 %! assert (mean (x, [3 2]), m, 4e-14); | 334 %! assert (mean (x, [3 2]), m, 4e-14); |
333 %! x(2,5,6,3) = NaN; | 335 %! x(2,5,6,3) = NaN; |
334 %! m(2,3) = NaN; | 336 %! m(2,3) = NaN; |
335 %! assert (mean (x, [3 2]), m, 4e-14); | 337 %! assert (mean (x, [3 2]), m, 4e-14); |
336 %! m(2,3) = 15.52301255230125; | 338 %! m(2,3) = 15.52301255230125; |
337 %! assert (mean (x, [3 2], "omitnan"), m, 4e-14); | 339 %! assert (mean (x, [3 2], "omitnan"), m, 4e-14); |
338 | 340 |
339 ## Test input validation | 341 ## Test input validation |
340 %!error <Invalid call to mean. Correct usage is> mean () | 342 %!error <Invalid call> mean () |
341 %!error <Invalid call to mean. Correct usage is> mean (1, 2, 3) | 343 %!error <Invalid call> mean (1, 2, 3) |
342 %!error <Invalid call to mean. Correct usage is> mean (1, 2, 3, 4, 5) | 344 %!error <Invalid call> mean (1, 2, 3, 4, 5) |
343 %!error <Invalid call to mean. Correct usage is> mean (1, "all", 3) | 345 %!error <Invalid call> mean (1, "all", 3) |
344 %!error <Invalid call to mean. Correct usage is> mean (1, "b") | 346 %!error <Invalid call> mean (1, "b") |
345 %!error <Invalid call to mean. Correct usage is> mean (1, 1, "foo") | 347 %!error <Invalid call> mean (1, 1, "foo") |
346 %!error <X must be either a numeric or logical> mean ({1:5}) | 348 %!error <X must be either a numeric or logical> mean ({1:5}) |
347 %!error <X must be either a numeric or logical> mean ("char") | 349 %!error <X must be either a numeric or logical> mean ("char") |
348 %!error <DIM must be a positive integer> mean (1, ones (2,2)) | 350 %!error <DIM must be a positive integer> mean (1, ones (2,2)) |
349 %!error <DIM must be a positive integer> mean (1, 1.5) | 351 %!error <DIM must be a positive integer> mean (1, 1.5) |
350 %!error <DIM must be a positive integer> mean (1, -1) | 352 %!error <DIM must be a positive integer> mean (1, -1) |