changeset 32364:43d010974a89

betainf.m: Improve integer input logic and validation (bug #64726) * betainf.m: Add shortcut function return for integer x inputs. Add x = 0 or 1 case to trivial input handling. Move tail input string check to beginning of function and replace repeat checks with state variable. Add numeric type input validation test and BISTs. Add BISTS to check inputs with mix of 0, 1, and other numbers for x for both tail options.
author Nicholas R. Jankowski <jankowski.nicholas@gmail.com>
date Fri, 29 Sep 2023 19:05:21 -0400
parents 446e747cd7d9
children e3c66ad99652
files scripts/specfun/betainc.m
diffstat 1 files changed, 48 insertions(+), 25 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/specfun/betainc.m	Fri Sep 29 18:00:43 2023 +0200
+++ b/scripts/specfun/betainc.m	Fri Sep 29 19:05:21 2023 -0400
@@ -87,12 +87,25 @@
     error ("betainc: all inputs must be real");
   endif
 
+  if ! (isnumeric (x) && isnumeric (a) && isnumeric (b))
+    error ("betainc: all inputs must be numeric");
+  endif
+
   ## Remember original shape of data, but convert to column vector for calcs.
   orig_sz = size (x);
   x = x(:);
   a = a(:);
   b = b(:);
 
+  switch (tolower (tail))
+    case "lower"
+      lower_upper_flag = true;
+    case "upper"
+      lower_upper_flag = false;
+    otherwise
+      error ("betainc: invalid value for TAIL");
+  endswitch
+
   if (any ((x < 0) | (x > 1)))
     error ("betainc: X must be in the range [0, 1]");
   endif
@@ -105,6 +118,25 @@
     error ("betainc: B must be strictly positive");
   endif
 
+  ## Initialize output array with output class matching x, shortcut trivial
+  ## case of integer class x.
+  if (isinteger (x))
+    ## For x = 0 or 1, the output always reduces to 0 or 1.  Input validation
+    ## ensures all interger inputs must be 0 or 1.
+    I = x;
+    return
+  else
+    I = zeros (size (x), class (x));
+  endif
+
+  ## Convert a,b to floating point if necessary.
+  if (isinteger (a))
+    a = double (a);
+  endif
+  if (isinteger (b))
+    b = double (b);
+  endif
+
   ## If any of the arguments is single then the output should be as well.
   if (strcmp (class (x), "single") || strcmp (class (a), "single")
       || strcmp (class (b), "single"))
@@ -113,38 +145,26 @@
     x = single (x);
   endif
 
-  ## Convert to floating point if necessary
-  if (isinteger (x))
-    I = double (x);
-  endif
-  if (isinteger (a))
-    a = double (a);
-  endif
-  if (isinteger (b))
-    b = double (b);
-  endif
-
-  ## Initialize output array
-  I = zeros (size (x), class (x));
 
   ## Trivial cases (long code here trades memory for speed)
+  x_trivial = (x == 0 | x == 1);
   a_one = (a == 1);
   b_one = (b == 1);
-  a_b_one = a_one & b_one;
+  a_b_x_triv = (a_one & b_one) | x_trivial;
   a_not_one = ! a_one;
   b_not_one = ! b_one;
-  non_trivial = a_not_one & b_not_one;
-  a_one &= b_not_one;
-  b_one &= a_not_one;
+  non_trivial = a_not_one & b_not_one & ! x_trivial;
+  a_one &= b_not_one | x_trivial;
+  b_one &= a_not_one | x_trivial;
 
-  if (strcmpi (tail, "lower"))
-    I(a_b_one) = x(a_b_one);
+  if (lower_upper_flag)
+    I(a_b_x_triv) = x(a_b_x_triv);
     ## See bug #62329.
     ## equivalent to "1 - (1 - x(a_one)) .^ b(a_one)", but less roundoff error
     I(a_one) = - expm1 (log1p (- x(a_one)) .* b(a_one));
     I(b_one) = x(b_one) .^ a(b_one);
-  elseif (strcmpi (tail, "upper"))
-    I(a_b_one) = 1 - x(a_b_one);
+  else
+    I(a_b_x_triv) = 1 - x(a_b_x_triv);
     ## equivalent to "(1 - x(a_one)) .^ b(a_one)", but less roundoff error
     I(a_one) = exp (log1p (- x(a_one)) .* b(a_one));
     ## equivalent to "1 - x(b_one) .^ a(b_one)", but less roundoff error
@@ -161,16 +181,14 @@
   a = a(non_trivial);
   b = b(non_trivial);
 
-  if (strcmpi (tail, "lower"))
+  if (lower_upper_flag)
     fflag = (x > a./(a + b));
     x(fflag) = 1 - x(fflag);
     [a(fflag), b(fflag)] = deal (b(fflag), a(fflag));
-  elseif (strcmpi (tail, "upper"))
+  else
     fflag = (x < (a ./ (a + b)));
     x(! fflag) = 1 - x(! fflag);
     [a(! fflag), b(! fflag)] = deal (b(! fflag), a(! fflag));
-  else
-    error ("betainc: invalid value for TAIL");
   endif
 
   f = zeros (size (x), class (x));
@@ -251,6 +269,8 @@
 %! assert (betainc (1, a, b), ones (20));
 %! assert (betainc (0, a, b, "upper"), ones (20));
 %! assert (betainc (1, a, b, "upper"), zeros (20));
+%! assert (betainc ([0 0.5 1], 2, 2), [0 0.5 1], eps);
+%! assert (betainc ([0 0.5 1], 2, 2, "upper"), [1 0.5 0], eps);
 
 %!test <*34405>
 %! assert (betainc (NaN, 1, 2), NaN);
@@ -270,6 +290,9 @@
 %!error <all inputs must be real> betainc (0.5i, 1, 2)
 %!error <all inputs must be real> betainc (0, 1i, 1)
 %!error <all inputs must be real> betainc (0, 1, 1i)
+%!error <all inputs must be numeric> betainc (char (1), 1, 2)
+%!error <all inputs must be numeric> betainc (0, char (1), 1)
+%!error <all inputs must be numeric> betainc (0, 1, char (1))
 %!error <X must be in the range \[0, 1\]> betainc (-0.1,1,1)
 %!error <X must be in the range \[0, 1\]> betainc (1.1,1,1)
 %!error <X must be in the range \[0, 1\]>