changeset 32408:6a1f89bb969c stable

betainf.m: Improve integer input logic and validation (bug #64726) * betainf.m: Add shortcut function return for integer x inputs. Add x = 0 or 1 case to trivial input handling. Move tail input string check to beginning of function and replace repeat checks with state variable. Add numeric type input validation test and BISTs. Add BISTS to check inputs with mix of 0, 1, and other numbers for x for both tail options.
author Nicholas R. Jankowski <jankowski.nicholas@gmail.com>
date Fri, 13 Oct 2023 23:20:34 -0400
parents 9afc383bb60a
children 88ecbd109776 6d56d47b9c03
files scripts/specfun/betainc.m
diffstat 1 files changed, 48 insertions(+), 25 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/specfun/betainc.m	Fri Oct 13 14:58:46 2023 +0200
+++ b/scripts/specfun/betainc.m	Fri Oct 13 23:20:34 2023 -0400
@@ -87,12 +87,25 @@
     error ("betainc: all inputs must be real");
   endif
 
+  if ! (isnumeric (x) && isnumeric (a) && isnumeric (b))
+    error ("betainc: all inputs must be numeric");
+  endif
+
   ## Remember original shape of data, but convert to column vector for calcs.
   orig_sz = size (x);
   x = x(:);
   a = a(:);
   b = b(:);
 
+  switch (tolower (tail))
+    case "lower"
+      lower_upper_flag = true;
+    case "upper"
+      lower_upper_flag = false;
+    otherwise
+      error ("betainc: invalid value for TAIL");
+  endswitch
+
   if (any ((x < 0) | (x > 1)))
     error ("betainc: X must be in the range [0, 1]");
   endif
@@ -105,6 +118,25 @@
     error ("betainc: B must be strictly positive");
   endif
 
+  ## Initialize output array with output class matching x, shortcut trivial
+  ## case of integer class x.
+  if (isinteger (x))
+    ## For x = 0 or 1, the output always reduces to 0 or 1.  Input validation
+    ## ensures all interger inputs must be 0 or 1.
+    I = x;
+    return
+  else
+    I = zeros (size (x), class (x));
+  endif
+
+  ## Convert a,b to floating point if necessary.
+  if (isinteger (a))
+    a = double (a);
+  endif
+  if (isinteger (b))
+    b = double (b);
+  endif
+
   ## If any of the arguments is single then the output should be as well.
   if (strcmp (class (x), "single") || strcmp (class (a), "single")
       || strcmp (class (b), "single"))
@@ -113,38 +145,26 @@
     x = single (x);
   endif
 
-  ## Convert to floating point if necessary
-  if (isinteger (x))
-    I = double (x);
-  endif
-  if (isinteger (a))
-    a = double (a);
-  endif
-  if (isinteger (b))
-    b = double (b);
-  endif
-
-  ## Initialize output array
-  I = zeros (size (x), class (x));
 
   ## Trivial cases (long code here trades memory for speed)
+  x_trivial = (x == 0 | x == 1);
   a_one = (a == 1);
   b_one = (b == 1);
-  a_b_one = a_one & b_one;
+  a_b_x_triv = (a_one & b_one) | x_trivial;
   a_not_one = ! a_one;
   b_not_one = ! b_one;
-  non_trivial = a_not_one & b_not_one;
-  a_one &= b_not_one;
-  b_one &= a_not_one;
+  non_trivial = a_not_one & b_not_one & ! x_trivial;
+  a_one &= b_not_one | x_trivial;
+  b_one &= a_not_one | x_trivial;
 
-  if (strcmpi (tail, "lower"))
-    I(a_b_one) = x(a_b_one);
+  if (lower_upper_flag)
+    I(a_b_x_triv) = x(a_b_x_triv);
     ## See bug #62329.
     ## equivalent to "1 - (1 - x(a_one)) .^ b(a_one)", but less roundoff error
     I(a_one) = - expm1 (log1p (- x(a_one)) .* b(a_one));
     I(b_one) = x(b_one) .^ a(b_one);
-  elseif (strcmpi (tail, "upper"))
-    I(a_b_one) = 1 - x(a_b_one);
+  else
+    I(a_b_x_triv) = 1 - x(a_b_x_triv);
     ## equivalent to "(1 - x(a_one)) .^ b(a_one)", but less roundoff error
     I(a_one) = exp (log1p (- x(a_one)) .* b(a_one));
     ## equivalent to "1 - x(b_one) .^ a(b_one)", but less roundoff error
@@ -161,16 +181,14 @@
   a = a(non_trivial);
   b = b(non_trivial);
 
-  if (strcmpi (tail, "lower"))
+  if (lower_upper_flag)
     fflag = (x > a./(a + b));
     x(fflag) = 1 - x(fflag);
     [a(fflag), b(fflag)] = deal (b(fflag), a(fflag));
-  elseif (strcmpi (tail, "upper"))
+  else
     fflag = (x < (a ./ (a + b)));
     x(! fflag) = 1 - x(! fflag);
     [a(! fflag), b(! fflag)] = deal (b(! fflag), a(! fflag));
-  else
-    error ("betainc: invalid value for TAIL");
   endif
 
   f = zeros (size (x), class (x));
@@ -251,6 +269,8 @@
 %! assert (betainc (1, a, b), ones (20));
 %! assert (betainc (0, a, b, "upper"), ones (20));
 %! assert (betainc (1, a, b, "upper"), zeros (20));
+%! assert (betainc ([0 0.5 1], 2, 2), [0 0.5 1], eps);
+%! assert (betainc ([0 0.5 1], 2, 2, "upper"), [1 0.5 0], eps);
 
 %!test <*34405>
 %! assert (betainc (NaN, 1, 2), NaN);
@@ -270,6 +290,9 @@
 %!error <all inputs must be real> betainc (0.5i, 1, 2)
 %!error <all inputs must be real> betainc (0, 1i, 1)
 %!error <all inputs must be real> betainc (0, 1, 1i)
+%!error <all inputs must be numeric> betainc (char (1), 1, 2)
+%!error <all inputs must be numeric> betainc (0, char (1), 1)
+%!error <all inputs must be numeric> betainc (0, 1, char (1))
 %!error <X must be in the range \[0, 1\]> betainc (-0.1,1,1)
 %!error <X must be in the range \[0, 1\]> betainc (1.1,1,1)
 %!error <X must be in the range \[0, 1\]>