changeset 32926:39b341c858a7

nchoosek: Improve input validation and code clarity (bug #65238)

* NEWS.10.md: Announce improvements in speed and precision.

* nchoosek.m: Use try/catch block to emit meaningful error message when inputs N, K are of differing integer types. Rename CamelCase variable "Mx" to "imax" (integer max) for clarity. Add more code comments. Add BIST tests for incompatible input types and for combinations of integer/floating point and double/single types.
author Rik <rik@octave.org>
date Sat, 03 Feb 2024 18:35:17 -0800
parents 2bd9d8c9287e
children 4b99c92fc2b2
files etc/NEWS.10.md scripts/specfun/nchoosek.m
diffstat 2 files changed, 34 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/etc/NEWS.10.md	Sat Feb 03 19:16:08 2024 +0100
+++ b/etc/NEWS.10.md	Sat Feb 03 18:35:17 2024 -0800
@@ -15,6 +15,8 @@
   * `--no-init-site` : Don't read site-wide configuration files at startup.
   * `--no-init-all` : Don't read any configuration files at startup.
 
+- `nchoosek` algorithm is now ~2x faster and provides greater precision. 
+
 ### Graphical User Interface
 
 ### Graphics backend
--- a/scripts/specfun/nchoosek.m	Sat Feb 03 19:16:08 2024 +0100
+++ b/scripts/specfun/nchoosek.m	Sat Feb 03 18:35:17 2024 -0800
@@ -112,33 +112,37 @@
   n = numel (v);
 
   if (n == 1 && isnumeric (v))
-
-    k = min (k, v-k);
-    is_int = (isinteger (v) || isinteger (k));
-    ## use Octave type propagation rules & validation of numeric data types by the range
-    if is_int
-      Mx = intmax ((v-k):v);
+    ## Compute number of combinations rather than actual set combinations.
+    try
+      ## Use subtraction operation to validate combining integer data types
+      ## and for type propagation rules between integer and floating point.
+      k = min (k, v-k);
+    catch
+      error ("nchoosek: incompatible input types for N (%s), K (%s)", ...
+             class (v), class (k));
+    end_try_catch
+    is_int = isinteger (k);
+    if (is_int)
+      imax = intmax (k);
     else
-      Mx = flintmax ((v-k):v);
+      imax = flintmax (k);
     endif
     C = 1;
     for i = 1:k
-      if (C * (v - k + i) >= Mx)
+      if (C * (v - k + i) >= imax)
         ## Avoid overflow / precision loss by determining the smallest
         ## possible factor of (C * (n-k+i)) and i via the gcd.
         ## Note that by design in each iteration
-        ##    1)  C will always increase (factor is always > 1)
-        ##    2)  C will always be a whole number.
+        ##   1) C will always increase (factor is always > 1).
+        ##   2) C will always be a whole number.
         ## Therefore, using the gcd will always provide the best possible
         ## solution until saturation / has the least precision loss.
-
         g1 = gcd (C, i);
         g2 = gcd (v - k + i, i/g1);
         C /= g1;
         C *= (v - k + i)/g2;
-        if (is_int && (C >= Mx))
-          ## We can finish here. Saturation will be reached
-          break;
+        if (is_int && (C == imax))
+          break;  # Stop here; saturation reached.
         endif
         C /= i/(g1 * g2);
       else
@@ -146,13 +150,15 @@
         C /= i;
       endif
     endfor
-    if (! is_int && C > Mx)
+    if (! is_int && C > imax)
       warning ("Octave:nchoosek:large-output-float", ...
                "nchoosek: possible loss of precision");
-    elseif (is_int && C == Mx)
+    elseif (is_int && C == imax)
       warning ("Octave:nchoosek:large-output-integer", ...
                "nchoosek: result may have saturated at intmax");
     endif
+
+  ## Compute actual set combinations
   elseif (k == 0)
     C = v(zeros (1, 0));  # Return 1x0 object for Matlab compatibility
   elseif (k == 1)
@@ -282,6 +288,15 @@
 %! assert (x, uint8 (252));
 %! assert (class (x), "uint8");
 
+## Test combining rules for integers and floating point
+%!test
+%! x = nchoosek (uint8 (10), single (5));
+%! assert (x, uint8 (252));
+
+%!test
+%! x = nchoosek (double (10), single (5));
+%! assert (x, single (252));
+
 %!test <*63538>
 %! x = nchoosek ([1:3]', 2);
 %! assert (x, [1 2; 1 3; 2 3]);
@@ -298,6 +313,7 @@
 %!error <N must be a non-negative integer .= K> nchoosek (100, 145)
 %!error <N must be a non-negative integer .= K> nchoosek (-100, 45)
 %!error <N must be a non-negative integer .= K> nchoosek (100.5, 45)
+%!error <incompatible input types> nchoosek (uint8 (15), uint16 (5))
 %!warning <possible loss of precision> nchoosek (100, 45);
 %!warning <result .* saturated> nchoosek (uint64 (80), uint64 (40));
 %!warning <result .* saturated> nchoosek (uint32 (80), uint32 (40));