changeset 20485:b3f56ed8d1f9

poissinv.m: Overhaul function for accuracy and performance (bug #34363). * poissinv.m: Eliminate while(1) loop which may never terminate. Subsitute for loop with limit of 20. If solution has not converged, use Mike Gile's analytic_approx routine to finish calculating a solution. Introduce special case for scalar lambda which speeds up calculation 6X.
author Lachlan Andrew <lachlanbis@gmail.com>
date Tue, 18 Aug 2015 11:35:23 -0700
parents df4165dfc676
children 9267c95dbd71
files scripts/statistics/distributions/poissinv.m
diffstat 1 files changed, 130 insertions(+), 20 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/statistics/distributions/poissinv.m	Mon Aug 17 09:32:20 2015 -0700
+++ b/scripts/statistics/distributions/poissinv.m	Tue Aug 18 11:35:23 2015 -0700
@@ -1,5 +1,7 @@
+## Copyright (C) 1995-2015 Kurt Hornik
+## Copyright (C) 2015 Lachlan Andrew
+## Copyright (C) 2014 Mike Giles
 ## Copyright (C) 2012 Rik Wehbring
-## Copyright (C) 1995-2015 Kurt Hornik
 ##
 ## This file is part of Octave.
 ##
@@ -23,7 +25,10 @@
 ## at @var{x} of the Poisson distribution with parameter @var{lambda}.
 ## @end deftypefn
 
-## Author: KH <Kurt.Hornik@wu-wien.ac.at>
+## Author: Lachlan <lachlanbis@gmail.com>
+## based on code by
+##      KH <Kurt.Hornik@wu-wien.ac.at>
+##      Mike Giles <mike.giles@maths.ox.ac.uk>
 ## Description: Quantile function of the Poisson distribution
 
 function inv = poissinv (x, lambda)
@@ -55,27 +60,132 @@
   k = (x == 1) & (lambda > 0);
   inv(k) = Inf;
 
-  k = find ((x > 0) & (x < 1) & (lambda > 0));
-  if (isscalar (lambda))
-    cdf = exp (-lambda) * ones (size (k));
-  else
-    cdf = exp (-lambda(k));
+  k = (x > 0) & (x < 1) & (lambda > 0);
+
+  limit = 20;                           # After 'limit' iterations, use approx
+  if (! isempty (k))
+    if (isscalar (lambda))
+      cdf = [(cumsum (poisspdf (0:limit-1,lambda))), 2];
+      y = x(:);                         # force to column
+      r = bsxfun (@le, y(k), cdf);
+      [~, inv(k)] = max (r, [], 2);     # find first instance of x <= cdf
+      inv(k) -= 1;
+    else
+      kk = find (k);
+      cdf = exp (-lambda(kk));
+      for i = 1:limit
+        m = find (cdf < x(kk));
+        if (any (m))
+          inv(kk(m)) += 1;
+          cdf(m) += poisspdf (i, lambda(kk(m)));
+        else
+          break;
+        endif
+      endfor
+    endif
+
+    ## Use Mike Giles's magic when inv isn't < limit
+    k &= (inv == limit);
+    if (any (k(:)))
+      if (isscalar (lambda))
+        lam = repmat (lambda, size (x));
+      else
+        lam = lambda;
+      endif
+      inv(k) = analytic_approx (x(k), lam(k));
+    endif
   endif
 
-  while (1)
-    m = find (cdf < x(k));
-    if (any (m))
-      inv(k(m)) += 1;
-      if (isscalar (lambda))
-        cdf(m) = cdf(m) + poisspdf (inv(k(m)), lambda);
-      else
-        cdf(m) = cdf(m) + poisspdf (inv(k(m)), lambda(k(m)));
+endfunction
+
+
+## The following is based on Mike Giles's CUDA implementation,
+## [http://people.maths.ox.ac.uk/gilesm/codes/poissinv/poissinv_cuda.h]
+## which is copyright by the University of Oxford
+## and is provided under the terms of the GNU GPLv3 license:
+## http://www.gnu.org/licenses/gpl.html
+
+function inv = analytic_approx (x, lambda)
+  s = norminv (x, 0, 1) ./ sqrt (lambda);
+  k = (s > -0.6833501) & (s < 1.777993);
+  ## use polynomial approximations in central region
+  if (any (k))
+    lam = lambda(k);
+    if (isscalar (s))
+      sk = s;
+    else
+      sk = s(k);
+    endif
+
+    ## polynomial approximation to f^{-1}(s) - 1
+    rm =  2.82298751e-07;
+    rm = -2.58136133e-06 + rm.*sk;
+    rm =  1.02118025e-05 + rm.*sk;
+    rm = -2.37996199e-05 + rm.*sk;
+    rm =  4.05347462e-05 + rm.*sk;
+    rm = -6.63730967e-05 + rm.*sk;
+    rm =  0.000124762566 + rm.*sk;
+    rm = -0.000256970731 + rm.*sk;
+    rm =  0.000558953132 + rm.*sk;
+    rm =  -0.00133129194 + rm.*sk;
+    rm =   0.00370367937 + rm.*sk;
+    rm =   -0.0138888706 + rm.*sk;
+    rm =     0.166666667 + rm.*sk;
+    rm =         sk + sk.*(rm.*sk);
+
+    ## polynomial approximation to correction c0(r)
+
+    t  =   1.86386867e-05;
+    t  =  -0.000207319499 + t.*rm;
+    t  =     0.0009689451 + t.*rm;
+    t  =   -0.00247340054 + t.*rm;
+    t  =    0.00379952985 + t.*rm;
+    t  =   -0.00386717047 + t.*rm;
+    t  =    0.00346960934 + t.*rm;
+    t  =   -0.00414125511 + t.*rm;
+    t  =    0.00586752093 + t.*rm;
+    t  =   -0.00838583787 + t.*rm;
+    t  =     0.0132793933 + t.*rm;
+    t  =     -0.027775536 + t.*rm;
+    t  =      0.333333333 + t.*rm;
+
+    ##  O(1/lam) correction
+
+    y  =   -0.00014585224;
+    y  =    0.00146121529 + y.*rm;
+    y  =   -0.00610328845 + y.*rm;
+    y  =     0.0138117964 + y.*rm;
+    y  =    -0.0186988746 + y.*rm;
+    y  =     0.0168155118 + y.*rm;
+    y  =     -0.013394797 + y.*rm;
+    y  =     0.0135698573 + y.*rm;
+    y  =    -0.0155377333 + y.*rm;
+    y  =     0.0174065334 + y.*rm;
+    y  =    -0.0198011178 + y.*rm;
+    y ./= lam;
+
+    inv(k) = floor (lam + (y+t)+lam.*rm);
+  endif
+
+  k = !k & (s > -sqrt (2));
+  if (any (k))
+    ## Newton iteration
+    r = 1 + s(k);
+    r2 = r + 1;
+    while (any (abs (r - r2) > 1e-5))
+      t = log (r);
+      r2 = r;
+      s2 = sqrt (2 * ((1-r) + r.*t));
+      s2(r<1) *= -1;
+      r = r2 - (s2 - s(k)) .* s2 ./ t;
+      if (r < 0.1 * r2)
+        r = 0.1 * r2;
       endif
-    else
-      break;
-    endif
-  endwhile
-
+    endwhile
+    t = log (r);
+    y = lambda(k) .* r + log (sqrt (2*r.*((1-r) + r.*t)) ./ abs (r-1)) ./ t;
+    inv(k) = floor (y - 0.0218 ./ (y + 0.065 * lambda(k)));
+  endif
 endfunction