changeset 32182:37e184a83cf4

corr: Avoid error with row vector inputs (bug #64395) * scripts/statistics/corr.m: Add check for row vector inputs that need different handling than provided by the updated cov function and ensure they return an appropriately sized NaN. Add BISTs to verify behavior. * scripts/statistics/cov.m: Add BISTs to verify compatible handling of row and column vector inputs.
author Nicholas R. Jankowski <jankowski.nicholas@gmail.com>
date Sun, 09 Jul 2023 18:13:49 -0400
parents ca72944d16a5
children e4e7bc93f5f7
files scripts/statistics/corr.m scripts/statistics/cov.m
diffstat 2 files changed, 55 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/statistics/corr.m	Fri Jul 07 13:37:43 2023 -0400
+++ b/scripts/statistics/corr.m	Sun Jul 09 18:13:49 2023 -0400
@@ -55,7 +55,7 @@
     print_usage ();
   endif
 
-  ## Input validation is done by cov.m.  Don't repeat tests here
+  ## Most input validation is done by cov.m.  Don't repeat tests here
 
   ## Special case, scalar is always 100% correlated with itself
   if (isscalar (x))
@@ -81,19 +81,57 @@
     ##        information is nonideal.  Consider implementing a more
     ##        efficient cov here as a subfunction to corr.
 
-    nx = columns(x);
+    ## Check for equal input sizes. cov ignores 2-D vector orientation
+    ## but corr does not.
+
+    if (! size_equal (x, y))
+      error ("corr: inputs must be the same size");
+    endif
+
+    nx = columns (x);
     c = cov ([x, y]);
+
+    ## Special handling for row vectors. std=0 along dim 1 and /0 will return
+    ## NaN, but cov will process along vector dimension. Keep special handling
+    ## after call to cov so it handles all other input validation and avoid
+    ## duplicating validation overhead for all other cases.
+    if (isrow (x))
+      if (isa (x, "single") || isa (y, "single"))
+        r = NaN (nx, "single");
+      else
+        r = NaN (nx);
+      endif
+      return;
+    endif
+
     c = c(1:nx, nx+1:end);
-    s = std (x)' * std (y);
+    s = std (x, [], 1)' * std (y, [], 1);
     r = c ./ s;
+
   else
     c = cov (x);
+
+    if isrow(x)  # Special handling for row vector.
+      if (isa (x, "single"))
+        r = NaN (nx, "single");
+      else
+        r = NaN (columns (x));
+      endif
+      return;
+    endif
+
     s = sqrt (diag (c));
     r = c ./ (s * s');
   endif
 
 endfunction
 
+%!test <*64395>
+%! x = [1, 2, 3];
+%! assert (corr (x), NaN (3));
+%! assert (corr (x'), 1, eps);
+%! assert (corr (x, x), NaN (3));
+%! assert (corr (x', x'), 1, eps);
 
 %!test
 %! x = rand (10);
@@ -121,6 +159,9 @@
 
 ## Test input validation
 %!error <Invalid call> corr ()
-%!error corr ([1; 2], ["A", "B"])
+%!error corr ([1; 2], ["A"; "B"])
 %!error corr (ones (2,2,2))
-%!error corr (ones (2,2), ones (2,2,2))
+%!error <inputs must be the same size> corr (ones (2,2), ones (2,2,2))
+%!error <inputs must be the same size> corr ([1,2,3], [1,2,3]')
+%!error <inputs must be the same size> corr ([1,2,3]', [1,2,3])
+
--- a/scripts/statistics/cov.m	Fri Jul 07 13:37:43 2023 -0400
+++ b/scripts/statistics/cov.m	Sun Jul 09 18:13:49 2023 -0400
@@ -329,6 +329,15 @@
 %! assert (isscalar (c));
 %! assert (c, 6);
 
+%!test <*64395>
+%! x = [1, 2, 3];
+%! assert (cov (x), 1, eps);
+%! assert (cov (x'), 1, eps);
+%! assert (cov (x, x), ones (2), eps);
+%! assert (cov (x', x), ones (2), eps);
+%! assert (cov (x, x'), ones (2), eps);
+%! assert (cov (x', x'), ones (2), eps);
+
 %!test
 %! x = [1 0; 1 0];
 %! y = [1 2; 1 1];