Mercurial > octave
changeset 32364:43d010974a89
betainf.m: Improve integer input logic and validation (bug #64726)
* betainf.m: Add shortcut function return for integer x inputs. Add x = 0
or 1 case to trivial input handling. Move tail input string check to
beginning of function and replace repeat checks with state variable. Add
numeric type input validation test and BISTs. Add BISTS to check inputs
with mix of 0, 1, and other numbers for x for both tail options.
author | Nicholas R. Jankowski <jankowski.nicholas@gmail.com> |
---|---|
date | Fri, 29 Sep 2023 19:05:21 -0400 |
parents | 446e747cd7d9 |
children | e3c66ad99652 |
files | scripts/specfun/betainc.m |
diffstat | 1 files changed, 48 insertions(+), 25 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/specfun/betainc.m Fri Sep 29 18:00:43 2023 +0200 +++ b/scripts/specfun/betainc.m Fri Sep 29 19:05:21 2023 -0400 @@ -87,12 +87,25 @@ error ("betainc: all inputs must be real"); endif + if ! (isnumeric (x) && isnumeric (a) && isnumeric (b)) + error ("betainc: all inputs must be numeric"); + endif + ## Remember original shape of data, but convert to column vector for calcs. orig_sz = size (x); x = x(:); a = a(:); b = b(:); + switch (tolower (tail)) + case "lower" + lower_upper_flag = true; + case "upper" + lower_upper_flag = false; + otherwise + error ("betainc: invalid value for TAIL"); + endswitch + if (any ((x < 0) | (x > 1))) error ("betainc: X must be in the range [0, 1]"); endif @@ -105,6 +118,25 @@ error ("betainc: B must be strictly positive"); endif + ## Initialize output array with output class matching x, shortcut trivial + ## case of integer class x. + if (isinteger (x)) + ## For x = 0 or 1, the output always reduces to 0 or 1. Input validation + ## ensures all interger inputs must be 0 or 1. + I = x; + return + else + I = zeros (size (x), class (x)); + endif + + ## Convert a,b to floating point if necessary. + if (isinteger (a)) + a = double (a); + endif + if (isinteger (b)) + b = double (b); + endif + ## If any of the arguments is single then the output should be as well. if (strcmp (class (x), "single") || strcmp (class (a), "single") || strcmp (class (b), "single")) @@ -113,38 +145,26 @@ x = single (x); endif - ## Convert to floating point if necessary - if (isinteger (x)) - I = double (x); - endif - if (isinteger (a)) - a = double (a); - endif - if (isinteger (b)) - b = double (b); - endif - - ## Initialize output array - I = zeros (size (x), class (x)); ## Trivial cases (long code here trades memory for speed) + x_trivial = (x == 0 | x == 1); a_one = (a == 1); b_one = (b == 1); - a_b_one = a_one & b_one; + a_b_x_triv = (a_one & b_one) | x_trivial; a_not_one = ! a_one; b_not_one = ! b_one; - non_trivial = a_not_one & b_not_one; - a_one &= b_not_one; - b_one &= a_not_one; + non_trivial = a_not_one & b_not_one & ! x_trivial; + a_one &= b_not_one | x_trivial; + b_one &= a_not_one | x_trivial; - if (strcmpi (tail, "lower")) - I(a_b_one) = x(a_b_one); + if (lower_upper_flag) + I(a_b_x_triv) = x(a_b_x_triv); ## See bug #62329. ## equivalent to "1 - (1 - x(a_one)) .^ b(a_one)", but less roundoff error I(a_one) = - expm1 (log1p (- x(a_one)) .* b(a_one)); I(b_one) = x(b_one) .^ a(b_one); - elseif (strcmpi (tail, "upper")) - I(a_b_one) = 1 - x(a_b_one); + else + I(a_b_x_triv) = 1 - x(a_b_x_triv); ## equivalent to "(1 - x(a_one)) .^ b(a_one)", but less roundoff error I(a_one) = exp (log1p (- x(a_one)) .* b(a_one)); ## equivalent to "1 - x(b_one) .^ a(b_one)", but less roundoff error @@ -161,16 +181,14 @@ a = a(non_trivial); b = b(non_trivial); - if (strcmpi (tail, "lower")) + if (lower_upper_flag) fflag = (x > a./(a + b)); x(fflag) = 1 - x(fflag); [a(fflag), b(fflag)] = deal (b(fflag), a(fflag)); - elseif (strcmpi (tail, "upper")) + else fflag = (x < (a ./ (a + b))); x(! fflag) = 1 - x(! fflag); [a(! fflag), b(! fflag)] = deal (b(! fflag), a(! fflag)); - else - error ("betainc: invalid value for TAIL"); endif f = zeros (size (x), class (x)); @@ -251,6 +269,8 @@ %! assert (betainc (1, a, b), ones (20)); %! assert (betainc (0, a, b, "upper"), ones (20)); %! assert (betainc (1, a, b, "upper"), zeros (20)); +%! assert (betainc ([0 0.5 1], 2, 2), [0 0.5 1], eps); +%! assert (betainc ([0 0.5 1], 2, 2, "upper"), [1 0.5 0], eps); %!test <*34405> %! assert (betainc (NaN, 1, 2), NaN); @@ -270,6 +290,9 @@ %!error <all inputs must be real> betainc (0.5i, 1, 2) %!error <all inputs must be real> betainc (0, 1i, 1) %!error <all inputs must be real> betainc (0, 1, 1i) +%!error <all inputs must be numeric> betainc (char (1), 1, 2) +%!error <all inputs must be numeric> betainc (0, char (1), 1) +%!error <all inputs must be numeric> betainc (0, 1, char (1)) %!error <X must be in the range \[0, 1\]> betainc (-0.1,1,1) %!error <X must be in the range \[0, 1\]> betainc (1.1,1,1) %!error <X must be in the range \[0, 1\]>