Mercurial > octave
changeset 32924:2bd9d8c9287e
nchoosek: Improve numerical stability and performance (bug #65238).
* script/specfun/nchoosek.m: Use iterative algorithm to prevent precision loss
due to overflow. Calculate greatest common denominator only when required.
Improves performance by a factor of 2 for some use-cases.
author | Hendrik Koerner <koerhen@web.de> |
---|---|
date | Sat, 03 Feb 2024 19:16:08 +0100 |
parents | ab643c7c1c10 |
children | 39b341c858a7 |
files | scripts/specfun/nchoosek.m |
diffstat | 1 files changed, 32 insertions(+), 33 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/specfun/nchoosek.m Sat Feb 03 15:33:45 2024 +0100 +++ b/scripts/specfun/nchoosek.m Sat Feb 03 19:16:08 2024 +0100 @@ -112,45 +112,44 @@ n = numel (v); if (n == 1 && isnumeric (v)) - ## Improve precision over direct call to prod(). - ## Steps: 1) Make a list of integers for numerator and denominator, - ## 2) filter out common factors, 3) multiply what remains. + k = min (k, v-k); - - if (isinteger (v) || isinteger (k)) - numer = (v-k+1):v; - denom = (1:k); + is_int = (isinteger (v) || isinteger (k)); + ## use Octave type propagation rules & validation of numeric data types by the range + if is_int + Mx = intmax ((v-k):v); else - ## For a ~25% performance boost, multiply values pairwise so there - ## are fewer elements in do/until loop which is the slow part. - ## Since Odd*Even is guaranteed to be Even, also take out a factor - ## of 2 from numerator and denominator. - if (rem (k, 2)) # k is odd - numer = [((v-k+1:v-(k+1)/2) .* (v-1:-1:v-(k-1)/2)) / 2, v]; - denom = [((1:(k-1)/2) .* (k-1:-1:(k+1)/2)) / 2, k]; - else # k is even - numer = ((v-k+1:v-k/2) .* (v:-1:v-k/2+1)) / 2; - denom = ((1:k/2) .* (k:-1:k/2+1)) / 2; - endif + Mx = flintmax ((v-k):v); endif + C = 1; + for i = 1:k + if (C * (v - k + i) >= Mx) + ## Avoid overflow / precision loss by determining the smallest + ## possible factor of (C * (n-k+i)) and i via the gcd. + ## Note that by design in each iteration + ## 1) C will always increase (factor is always > 1) + ## 2) C will always be a whole number. + ## Therefore, using the gcd will always provide the best possible + ## solution until saturation / has the least precision loss. - ## Remove common factors from numerator and denominator - do - for i = numel (denom):-1:1 - factors = gcd (denom(i), numer); - [f, j] = max (factors); - denom(i) /= f; - numer(j) /= f; - endfor - denom = denom(denom > 1); - numer = numer(numer > 1); - until (isempty (denom)) - - C = prod (numer, "native"); - if (isfloat (C) && C > flintmax (C)) + g1 = gcd (C, i); + g2 = gcd (v - k + i, i/g1); + C /= g1; + C *= (v - k + i)/g2; + if (is_int && (C >= Mx)) + ## We can finish here. Saturation will be reached + break; + endif + C /= i/(g1 * g2); + else + C *= (v - k + i); + C /= i; + endif + endfor + if (! is_int && C > Mx) warning ("Octave:nchoosek:large-output-float", ... "nchoosek: possible loss of precision"); - elseif (isinteger (C) && C == intmax (C)) + elseif (is_int && C == Mx) warning ("Octave:nchoosek:large-output-integer", ... "nchoosek: result may have saturated at intmax"); endif