Mercurial > octave
changeset 32182:37e184a83cf4
corr: Avoid error with row vector inputs (bug #64395)
* scripts/statistics/corr.m: Add check for row vector inputs that need
different handling than provided by the updated cov function and ensure
they return an appropriately sized NaN. Add BISTs to verify behavior.
* scripts/statistics/cov.m: Add BISTs to verify compatible handling of row
and column vector inputs.
author | Nicholas R. Jankowski <jankowski.nicholas@gmail.com> |
---|---|
date | Sun, 09 Jul 2023 18:13:49 -0400 |
parents | ca72944d16a5 |
children | e4e7bc93f5f7 |
files | scripts/statistics/corr.m scripts/statistics/cov.m |
diffstat | 2 files changed, 55 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/statistics/corr.m Fri Jul 07 13:37:43 2023 -0400 +++ b/scripts/statistics/corr.m Sun Jul 09 18:13:49 2023 -0400 @@ -55,7 +55,7 @@ print_usage (); endif - ## Input validation is done by cov.m. Don't repeat tests here + ## Most input validation is done by cov.m. Don't repeat tests here ## Special case, scalar is always 100% correlated with itself if (isscalar (x)) @@ -81,19 +81,57 @@ ## information is nonideal. Consider implementing a more ## efficient cov here as a subfunction to corr. - nx = columns(x); + ## Check for equal input sizes. cov ignores 2-D vector orientation + ## but corr does not. + + if (! size_equal (x, y)) + error ("corr: inputs must be the same size"); + endif + + nx = columns (x); c = cov ([x, y]); + + ## Special handling for row vectors. std=0 along dim 1 and /0 will return + ## NaN, but cov will process along vector dimension. Keep special handling + ## after call to cov so it handles all other input validation and avoid + ## duplicating validation overhead for all other cases. + if (isrow (x)) + if (isa (x, "single") || isa (y, "single")) + r = NaN (nx, "single"); + else + r = NaN (nx); + endif + return; + endif + c = c(1:nx, nx+1:end); - s = std (x)' * std (y); + s = std (x, [], 1)' * std (y, [], 1); r = c ./ s; + else c = cov (x); + + if isrow(x) # Special handling for row vector. + if (isa (x, "single")) + r = NaN (nx, "single"); + else + r = NaN (columns (x)); + endif + return; + endif + s = sqrt (diag (c)); r = c ./ (s * s'); endif endfunction +%!test <*64395> +%! x = [1, 2, 3]; +%! assert (corr (x), NaN (3)); +%! assert (corr (x'), 1, eps); +%! assert (corr (x, x), NaN (3)); +%! assert (corr (x', x'), 1, eps); %!test %! x = rand (10); @@ -121,6 +159,9 @@ ## Test input validation %!error <Invalid call> corr () -%!error corr ([1; 2], ["A", "B"]) +%!error corr ([1; 2], ["A"; "B"]) %!error corr (ones (2,2,2)) -%!error corr (ones (2,2), ones (2,2,2)) +%!error <inputs must be the same size> corr (ones (2,2), ones (2,2,2)) +%!error <inputs must be the same size> corr ([1,2,3], [1,2,3]') +%!error <inputs must be the same size> corr ([1,2,3]', [1,2,3]) +
--- a/scripts/statistics/cov.m Fri Jul 07 13:37:43 2023 -0400 +++ b/scripts/statistics/cov.m Sun Jul 09 18:13:49 2023 -0400 @@ -329,6 +329,15 @@ %! assert (isscalar (c)); %! assert (c, 6); +%!test <*64395> +%! x = [1, 2, 3]; +%! assert (cov (x), 1, eps); +%! assert (cov (x'), 1, eps); +%! assert (cov (x, x), ones (2), eps); +%! assert (cov (x', x), ones (2), eps); +%! assert (cov (x, x'), ones (2), eps); +%! assert (cov (x', x'), ones (2), eps); + %!test %! x = [1 0; 1 0]; %! y = [1 2; 1 1];