changeset 32265:4f287cb8002c

corr.m: Avoid producing error for unequal number of columns (bug #64555) * corr.m: Change size_equal check to equal number of rows check for multiple inputs. Add check for ndims > 3 that errors if not paired with a row vector for compatibility. Correct row vector NaN output size to depend on number of columns in both x and y. Add input validation tests for codepaths that cannot be passed to cov. Move call to cov after input validation and special handling for row vectors and nD inputs. Add BISTs covering input size checks, ndims >3, and input validation. Add docstring note about requiring equal rows for x and y. * cov.m: Add islogical test to varagin checking for opt. Add BISTs covering y input validation and cov(x, y, true/false) input forms.
author Nicholas R. Jankowski <jankowski.nicholas@gmail.com>
date Thu, 17 Aug 2023 10:42:07 -0400
parents f0a9f9cc3c15
children e852dad3734f
files scripts/statistics/corr.m scripts/statistics/cov.m
diffstat 2 files changed, 79 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/statistics/corr.m	Wed Aug 16 17:14:24 2023 -0400
+++ b/scripts/statistics/corr.m	Thu Aug 17 10:42:07 2023 -0400
@@ -32,6 +32,7 @@
 ## a variable, then the @w{(@var{i}, @var{j})-th} entry of
 ## @code{corr (@var{x}, @var{y})} is the correlation between the
 ## @var{i}-th variable in @var{x} and the @var{j}-th variable in @var{y}.
+## @var{x} and @var{y} must have the same number of rows (observations).
 ## @tex
 ## $$
 ## {\rm corr}(x,y) = {{\rm cov}(x,y) \over {\rm std}(x) \, {\rm std}(y)}
@@ -55,10 +56,13 @@
     print_usage ();
   endif
 
-  ## Most input validation is done by cov.m.  Don't repeat tests here.
+  if (! (isnumeric (x) || islogical (x)))
+    error ("corr: X must be a numeric vector or matrix");
+  endif
 
   ## No check for division by zero error, which happens only when
   ## there is a constant vector and should be rare.
+
   if (nargin == 2)
     ## Adjust for Octave 9.1.0 compatibility behavior change in two-input cov.
     ## cov now treats cov(x,y) as cov(x(:),y(:)), returning a 2x2 covariance
@@ -72,30 +76,46 @@
     ##        efficient cov here as a subfunction to corr.  At that point,
     ##        input validation will need to be coded back into this function.
 
-    ## Check for equal input sizes.  cov ignores 2-D vector orientation,
-    ## but corr does not.
-    if (! size_equal (x, y))
-      error ("corr: X and Y must be the same size");
+    if (! (isnumeric (y) || islogical (y)))
+      error ("corr: Y must be a numeric vector or matrix");
     endif
 
-    c = cov ([x, y]);  # Also performs input validation of x and y.
-    nc = columns (x);
+    ## Check for equal number of rows before concatenating inputs for cov.
+    ## This will also catch mixed orientation 2-D vectors which cov allows but
+    ## corr should not.
+    if (rows (x) != rows (y))
+      error ("corr: X and Y must have the same number of rows");
+    endif
+
+    rowx = isrow (x);
+    rowy = isrow (y);
+
+    if ((! rowy && ndims (x) > 2) || (! rowx && ndims (y) > 2))
+      ## For compatibility 3D is permitted only if other input is row vector
+      ## which results in NaNs.
+        error (["corr: X and Y must be two dimensional unless the other ", ...
+                "input is a scalar or row vector"]);
+    endif
 
     ## Special handling for row vectors.  std=0 along dim 1 and division by 0
     ## will return NaN, but cov will process along vector dimension.  Keep
     ## special handling after call to cov so it handles all other input
     ## validation and avoid duplicating validation overhead for all other
     ## cases.
-    if (isrow (x))
+    ncx = columns (x);
+    ncy = columns (y);
+    if (rowx || rowy)
       if (isa (x, "single") || isa (y, "single"))
-        r = NaN (nc, "single");
+        r = NaN (ncx, ncy, "single");
       else
-        r = NaN (nc);
+        r = NaN (ncx, ncy);
       endif
       return;
     endif
 
-    c = c(1:nc, nc+1:end);
+    c = cov ([x, y]);  # Also performs input validation of x and y.
+
+    c = c(1:ncx, ncx+1:end);
     s = std (x, [], 1)' * std (y, [], 1);
     r = c ./ s;
 
@@ -144,6 +164,17 @@
 %!assert (corr (5), NaN)
 %!assert (corr (single (5)), single (NaN))
 
+## Special case: constant vectors
+%!assert (corr ([5; 5; 5], [1; 2; 3]), NaN)
+%!assert (corr ([1; 2; 3], [5;5;5]), NaN)
+
+%!test <*64555>
+%! x = [1 2; 3 4; 5 6];
+%! y = [1 2 3]';
+%! assert (corr (x, y), [1; 1]);
+%! assert (corr (y, x), [1, 1]);
+%! assert (corr (x, [y, y]), [1 1; 1 1])
+
 %!test <*64395>
 %! x = [1, 2, 3];
 %! assert (corr (x), NaN (3));
@@ -158,10 +189,27 @@
 %! assert (corr (x, x), single (NaN (3)));
 %! assert (corr (x', x'), 1, single (eps));
 
+%!assert <*64555> (corr (1, rand (1, 10)), NaN (1, 10));
+%!assert <*64555> (corr (rand (1, 10), 1), NaN (10, 1));
+%!assert <*64555> (corr (rand (1, 10), rand (1, 10)), NaN (10, 10));
+%!assert <*64555> (corr (rand (1, 5), rand (1, 10)), NaN (5, 10));
+%!assert <*64555> (corr (5, rand (1, 10, 5)), NaN (1, 10));
+%!assert <*64555> (corr (rand (1, 5, 5), rand (1, 10)), NaN (5, 10));
+%!assert <*64555> (corr (rand (1, 5, 5, 99), rand (1, 10)), NaN (5, 10));
+
 ## Test input validation
 %!error <Invalid call> corr ()
-%!error <X and Y must be the same size> corr (ones (2,2), ones (2,2,2))
-%!error <X and Y must be the same size> corr ([1,2,3], [1,2,3]')
-%!error <X and Y must be the same size> corr ([1,2,3]', [1,2,3])
-%!error corr ([1; 2], ["A"; "B"])
-%!error corr (ones (2,2,2))
+%!error corr (1, 2, 3)
+%!error <X must be a> corr ("foo")
+%!error <X must be a> corr ({123})
+%!error <X must be a> corr (struct())
+%!error <Y must be a> corr (1, "foo")
+%!error <Y must be a> corr (1, {123})
+%!error <Y must be a> corr (1, struct())
+%!error <Y must be a> corr ([1; 2], ["A"; "B"])
+%!error <X and Y must have the same number of rows> corr (ones (2,2), ones (3,2))
+%!error <X and Y must have the same number of rows> corr ([1,2,3], [1,2,3]')
+%!error <X and Y must have the same number of rows> corr ([1,2,3]', [1,2,3])
+%!error <X and Y must have the same number of rows> corr (ones (2,2), ones (1,2,2))
+%!error <X and Y must be two dimensional unless> corr (ones (2,2), ones (2,2,2))
+%!error corr (ones (2,2,2)) # Single input validation handled by corr
--- a/scripts/statistics/cov.m	Wed Aug 16 17:14:24 2023 -0400
+++ b/scripts/statistics/cov.m	Thu Aug 17 10:42:07 2023 -0400
@@ -154,8 +154,9 @@
         if (ischar (varargin{end}))
           nanflag = lower (varargin{end});
 
-          if (isscalar (varargin{1}) && ...
-            (varargin{1} == 0 || varargin{1} == 1))
+          if ((isnumeric (varargin{1}) || islogical (varargin{1})) && ...
+              isscalar (varargin{1}) && ...
+              (varargin{1} == 0 || varargin{1} == 1))
             opt = double (varargin {1});
 
           else
@@ -173,8 +174,9 @@
         if (ischar (varargin{end}))
           nanflag = lower (varargin{end});
 
-        elseif (isscalar (varargin{1}) && ...
-               (varargin{1} == 0 || varargin{1} == 1))
+        elseif ((isnumeric (varargin{1}) || islogical (varargin{1})) && ...
+                isscalar (varargin{1}) && ...
+                (varargin{1} == 0 || varargin{1} == 1))
           opt = double (varargin {1});
 
         else
@@ -388,6 +390,11 @@
 %!assert (cov (0, logical(0)), double(0))
 %!assert (cov (logical(0), 0), double(0))
 %!assert (cov (logical([0 1; 1 0]), logical([0 1; 1 0])), double ([1 1;1 1]./3))
+%!assert (cov ([1 2 3], [3 4 5], 0), [1 1; 1 1])
+%!assert (cov ([1 2 3], [3 4 5], false), [1 1; 1 1])
+%!assert (cov ([1 2 3], [3 4 5], 1), [2/3 2/3; 2/3 2/3], eps)
+%!assert (cov ([1 2 3], [3 4 5], true), [2/3 2/3; 2/3 2/3], eps)
+
 
 ## Test empty and NaN handling (bug #50583)
 %!assert <*50583> (cov ([]), NaN)
@@ -499,14 +506,16 @@
 %!error <X must be a> cov ("foo")
 %!error <X must be a> cov ({123})
 %!error <X must be a> cov (struct())
-%!error <X must be a> cov (ones (2, 2, 2))
-%!error <X must be a> cov (ones (1, 0, 2))
+%!error <X must be a 2-D> cov (ones (2, 2, 2))
+%!error <X must be a 2-D> cov (ones (1, 0, 2))
 %!error <only one NANFLAG> cov (1, "foo", 0, "includenan")
 %!error <only one NANFLAG> cov (1, 1, "foo", "includenan")
 %!error <normalization paramter OPT must be> cov (1, 2, [])
 %!error <normalization paramter OPT must be> cov (1, 2, 1.1)
 %!error <normalization paramter OPT must be> cov (1, 2, -1)
 %!error <normalization paramter OPT must be> cov (1, 2, [0 1])
+%!error <Y must be a> cov (1, {123})
+%!error <Y must be a> cov (1, struct())
 %!error <X and Y must have the same number> cov (5,[1 2])
 %!error <X and Y must have the same number> cov (ones (2, 2), ones (2, 2, 2))
 %!error <X and Y must have the same number> cov (ones (2, 2), ones (3, 2))