changeset 32924:2bd9d8c9287e

nchoosek: Improve numerical stability and performance (bug #65238). * script/specfun/nchoosek.m: Use iterative algorithm to prevent precision loss due to overflow. Calculate greatest common denominator only when required. Improves performance by a factor of 2 for some use-cases.
author Hendrik Koerner <koerhen@web.de>
date Sat, 03 Feb 2024 19:16:08 +0100
parents ab643c7c1c10
children 39b341c858a7
files scripts/specfun/nchoosek.m
diffstat 1 files changed, 32 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/specfun/nchoosek.m	Sat Feb 03 15:33:45 2024 +0100
+++ b/scripts/specfun/nchoosek.m	Sat Feb 03 19:16:08 2024 +0100
@@ -112,45 +112,44 @@
   n = numel (v);
 
   if (n == 1 && isnumeric (v))
-    ## Improve precision over direct call to prod().
-    ## Steps: 1) Make a list of integers for numerator and denominator,
-    ## 2) filter out common factors, 3) multiply what remains.
+
     k = min (k, v-k);
-
-    if (isinteger (v) || isinteger (k))
-      numer = (v-k+1):v;
-      denom = (1:k);
+    is_int = (isinteger (v) || isinteger (k));
+    ## use Octave type propagation rules & validation of numeric data types by the range
+    if is_int
+      Mx = intmax ((v-k):v);
     else
-      ## For a ~25% performance boost, multiply values pairwise so there
-      ## are fewer elements in do/until loop which is the slow part.
-      ## Since Odd*Even is guaranteed to be Even, also take out a factor
-      ## of 2 from numerator and denominator.
-      if (rem (k, 2))  # k is odd
-        numer = [((v-k+1:v-(k+1)/2) .* (v-1:-1:v-(k-1)/2)) / 2, v];
-        denom = [((1:(k-1)/2) .* (k-1:-1:(k+1)/2)) / 2, k];
-      else             # k is even
-        numer = ((v-k+1:v-k/2) .* (v:-1:v-k/2+1)) / 2;
-        denom = ((1:k/2) .* (k:-1:k/2+1)) / 2;
-      endif
+      Mx = flintmax ((v-k):v);
     endif
+    C = 1;
+    for i = 1:k
+      if (C * (v - k + i) >= Mx)
+        ## Avoid overflow / precision loss by determining the smallest
+        ## possible factor of (C * (n-k+i)) and i via the gcd.
+        ## Note that by design in each iteration
+        ##    1)  C will always increase (factor is always > 1)
+        ##    2)  C will always be a whole number.
+        ## Therefore, using the gcd will always provide the best possible
+        ## solution until saturation / has the least precision loss.
 
-    ## Remove common factors from numerator and denominator
-    do
-      for i = numel (denom):-1:1
-        factors = gcd (denom(i), numer);
-        [f, j] = max (factors);
-        denom(i) /= f;
-        numer(j) /= f;
-      endfor
-      denom = denom(denom > 1);
-      numer = numer(numer > 1);
-    until (isempty (denom))
-
-    C = prod (numer, "native");
-    if (isfloat (C) && C > flintmax (C))
+        g1 = gcd (C, i);
+        g2 = gcd (v - k + i, i/g1);
+        C /= g1;
+        C *= (v - k + i)/g2;
+        if (is_int && (C >= Mx))
+          ## We can finish here. Saturation will be reached
+          break;
+        endif
+        C /= i/(g1 * g2);
+      else
+        C *= (v - k + i);
+        C /= i;
+      endif
+    endfor
+    if (! is_int && C > Mx)
       warning ("Octave:nchoosek:large-output-float", ...
                "nchoosek: possible loss of precision");
-    elseif (isinteger (C) && C == intmax (C))
+    elseif (is_int && C == Mx)
       warning ("Octave:nchoosek:large-output-integer", ...
                "nchoosek: result may have saturated at intmax");
     endif