changeset 9754:4219e5cf773d

improve interp1 and pchip
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 22 Oct 2009 10:12:54 +0200
parents 892e2aa7bc75
children 4f4873f6f873
files scripts/ChangeLog scripts/general/interp1.m scripts/polynomial/pchip.m
diffstat 3 files changed, 129 insertions(+), 138 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/ChangeLog	Thu Oct 22 08:56:58 2009 +0200
+++ b/scripts/ChangeLog	Thu Oct 22 10:12:54 2009 +0200
@@ -1,3 +1,9 @@
+2009-10-22  Jaroslav Hajek  <highegg@gmail.com>
+
+	* general/interp1.m: Perform optimizations, improve code (use switch
+	instead of multiple ifs).
+	* polynomial/pchip.m: Employ more optimized formulas (from SLATEC).
+
 2009-10-22  Soren Hauberg  <hauberg@gmail.com>
 
 	* image/autumn.m, image/bone.m, image/cool.m, image/copper.m,
--- a/scripts/general/interp1.m	Thu Oct 22 08:56:58 2009 +0200
+++ b/scripts/general/interp1.m	Thu Oct 22 10:12:54 2009 +0200
@@ -126,15 +126,13 @@
 
   ## reshape matrices for convenience
   x = x(:);
-  nx = size (x, 1);
-  if (isvector(y) && size (y, 1) == 1)
+  nx = rows (x);
+  if (isvector (y))
     y = y(:);
   endif
-  ndy = ndims (y);
   szy = size (y);
-  ny = szy(1);
-  nc = prod (szy(2:end));
-  y = reshape (y, ny, nc);
+  y = y(:,:);
+  [ny, nc] = size (y);
   szx = size (xi);
   xi = xi(:);
 
@@ -143,61 +141,42 @@
     error ("interp1: table too short");
   endif
 
-  ## determine which values are out of range and set them to extrap,
-  ## unless extrap == "extrap" in which case, extrapolate them like we
-  ## should be doing in the first place.
-  minx = x(1);
-  maxx = x(nx);
-  if (minx > maxx)
-    tmp = minx;
-    minx = maxx;
-    maxx = tmp;
-  endif
-  if (method(1) == "*")
-    dx = x(2) - x(1);
+  ## check whether x is sorted; sort if not.
+  if (! issorted (x))
+    [x, p] = sort (x);
+    y = y(p,:);
   endif
 
-  if (! pp)
-    if (ischar (extrap) && strcmp (extrap, "extrap"))
-      range = 1:size (xi, 1);
-      yi = zeros (size (xi, 1), size (y, 2));
-    else
-      range = find (xi >= minx & xi <= maxx);
-      yi = extrap * ones (size (xi, 1), size (y, 2));
-      if (isempty (range))
-	if (! isvector (y) && length (szx) == 2
-	    && (szx(1) == 1 || szx(2) == 1))
-	  if (szx(1) == 1)
-	    yi = reshape (yi, [szx(2), szy(2:end)]);
-	  else
-	    yi = reshape (yi, [szx(1), szy(2:end)]);
-	  endif
-	else
-	  yi = reshape (yi, [szx, szy(2:end)]);
-        endif
-        return; 
-      endif
-      xi = xi(range);
+  starmethod = method(1) == "*";
+
+  if (starmethod)
+    dx = x(2) - x(1);
+  else
+    if (any (x(1:nx-1) == x(2:nx)))
+      error ("interp1: table must be strictly monotonic");
     endif
   endif
 
-  if (strcmp (method, "nearest"))
+  ## Proceed with interpolating by all methods.
+
+  switch (method)
+  case "nearest"
     if (pp)
       yi = mkpp ([x(1); (x(1:end-1)+x(2:end))/2; x(end)], y, szy(2:end));
     else
       idx = lookup (0.5*(x(1:nx-1)+x(2:nx)), xi) + 1;
-      yi(range,:) = y(idx,:);
+      yi = y(idx,:);
     endif
-  elseif (strcmp (method, "*nearest"))
+  case "*nearest"
     if (pp)
       yi = mkpp ([x(1); x(1)+[0.5:(ny-1)]'*dx; x(nx)], y, szy(2:end));
     else
       idx = max (1, min (ny, floor((xi-x(1))/dx+1.5)));
-      yi(range,:) = y(idx,:);
+      yi = y(idx,:);
     endif
-  elseif (strcmp (method, "linear"))
-    dy = y(2:ny,:) - y(1:ny-1,:);
-    dx = x(2:nx) - x(1:nx-1);
+  case "linear"
+    dy = diff (y);
+    dx = diff (x);
     if (pp)
       yi = mkpp (x, [dy./dx, y(1:end-1)], szy(2:end));
     else
@@ -205,24 +184,23 @@
       idx = lookup (x, xi, "lr");
       ## use the endpoints of the interval to define a line
       s = (xi - x(idx))./dx(idx);
-      yi(range,:) = s(:,ones(1,nc)).*dy(idx,:) + y(idx,:);
+      yi = bsxfun (@times, s, dy(idx,:)) + y(idx,:);
     endif
-  elseif (strcmp (method, "*linear"))
+  case "*linear"
+    dy = diff (y);
     if (pp)
-      dy = [y(2:ny,:) - y(1:ny-1,:)];
       yi = mkpp (x(1) + [0:ny-1]*dx, [dy./dx, y(1:end-1)], szy(2:end));
     else
       ## find the interval containing the test point
       t = (xi - x(1))/dx + 1;
-      idx = max(1,min(ny,floor(t)));
+      idx = max (1, min (ny - 1, floor (t)));
 
       ## use the endpoints of the interval to define a line
-      dy = [y(2:ny,:) - y(1:ny-1,:); y(ny,:) - y(ny-1,:)];
       s = t - idx;
-      yi(range,:) = s(:,ones(1,nc)).*dy(idx,:) + y(idx,:); 
+      yi = bsxfun (@times, s, dy(idx,:)) + y(idx,:);
     endif
-  elseif (strcmp (method, "pchip") || strcmp (method, "*pchip"))
-    if (nx == 2 || method(1) == "*") 
+  case {"pchip", "*pchip"}
+    if (nx == 2 || starmethod) 
       x = linspace (x(1), x(nx), ny);
     endif
     ## Note that pchip's arguments are transposed relative to interp1
@@ -230,69 +208,67 @@
       yi = pchip (x.', y.');
       yi.d = szy(2:end);
     else
-      yi(range,:) = pchip (x.', y.', xi.').';
-    endif
-
-  elseif (strcmp (method, "cubic") || (strcmp (method, "*cubic") && pp))
-    ## FIXME Is there a better way to treat pp return return and *cubic
-    if (method(1) == "*")
-      x = linspace (x(1), x(nx), ny).'; 
-      nx = ny;
+      yi = pchip (x.', y.', xi.').';
     endif
 
-    if (nx < 4 || ny < 4)
-      error ("interp1: table too short");
-    endif
-    idx = lookup (x(2:nx-1), xi, "lr");
-
-    ## Construct cubic equations for each interval using divided
-    ## differences (computation of c and d don't use divided differences
-    ## but instead solve 2 equations for 2 unknowns). Perhaps
-    ## reformulating this as a lagrange polynomial would be more efficient.
-    i = 1:nx-3;
-    J = ones (1, nc);
-    dx = diff (x);
-    dx2 = x(i+1).^2 - x(i).^2;
-    dx3 = x(i+1).^3 - x(i).^3;
-    a = diff (y, 3)./dx(i,J).^3/6;
-    b = (diff (y(1:nx-1,:), 2)./dx(i,J).^2 - 6*a.*x(i+1,J))/2;
-    c = (diff (y(1:nx-2,:), 1) - a.*dx3(:,J) - b.*dx2(:,J))./dx(i,J);
-    d = y(i,:) - ((a.*x(i,J) + b).*x(i,J) + c).*x(i,J);
-
-    if (pp)
-      xs = [x(1);x(3:nx-2)];
-      yi = mkpp ([x(1);x(3:nx-2);x(nx)], 
-		 [a(:), (b(:) + 3.*xs(:,J).*a(:)), ... 
-		  (c(:) + 2.*xs(:,J).*b(:) + 3.*xs(:,J)(:).^2.*a(:)), ...
-		  (d(:) + xs(:,J).*c(:) + xs(:,J).^2.*b(:) + ...
-		   xs(:,J).^3.*a(:))], szy(2:end));
-    else
-      yi(range,:) = ((a(idx,:).*xi(:,J) + b(idx,:)).*xi(:,J) ...
-		     + c(idx,:)).*xi(:,J) + d(idx,:);
-    endif
-  elseif (strcmp (method, "*cubic"))
+  case {"cubic", "*cubic"}
     if (nx < 4 || ny < 4)
       error ("interp1: table too short");
     endif
 
-    ## From: Miloje Makivic 
-    ## http://www.npac.syr.edu/projects/nasa/MILOJE/final/node36.html
-    t = (xi - x(1))/dx + 1;
-    idx = max (min (floor (t), ny-2), 2);
-    t = t - idx;
-    t2 = t.*t;
-    tp = 1 - 0.5*t;
-    a = (1 - t2).*tp;
-    b = (t2 + t).*tp;
-    c = (t2 - t).*tp/3;
-    d = (t2 - 1).*t/6;
-    J = ones (1, nc);
+    ## FIXME Is there a better way to treat pp return and *cubic
+    if (starmethod && ! pp)
+      ## From: Miloje Makivic 
+      ## http://www.npac.syr.edu/projects/nasa/MILOJE/final/node36.html
+      t = (xi - x(1))/dx + 1;
+      idx = max (min (floor (t), ny-2), 2);
+      t = t - idx;
+      t2 = t.*t;
+      tp = 1 - 0.5*t;
+      a = (1 - t2).*tp;
+      b = (t2 + t).*tp;
+      c = (t2 - t).*tp/3;
+      d = (t2 - 1).*t/6;
+      J = ones (1, nc);
+
+      yi = a(:,J) .* y(idx,:) + b(:,J) .* y(idx+1,:) ...
+      + c(:,J) .* y(idx-1,:) + d(:,J) .* y(idx+2,:);
+    else
+      if (starmethod)
+        x = linspace (x(1), x(nx), ny).'; 
+        nx = ny;
+      endif
+
+      idx = lookup (x(2:nx-1), xi, "lr");
 
-    yi(range,:) = a(:,J) .* y(idx,:) + b(:,J) .* y(idx+1,:) ...
-		  + c(:,J) .* y(idx-1,:) + d(:,J) .* y(idx+2,:);
+      ## Construct cubic equations for each interval using divided
+      ## differences (computation of c and d don't use divided differences
+      ## but instead solve 2 equations for 2 unknowns). Perhaps
+      ## reformulating this as a lagrange polynomial would be more efficient.
+      i = 1:nx-3;
+      J = ones (1, nc);
+      dx = diff (x);
+      dx2 = x(i+1).^2 - x(i).^2;
+      dx3 = x(i+1).^3 - x(i).^3;
+      a = diff (y, 3)./dx(i,J).^3/6;
+      b = (diff (y(1:nx-1,:), 2)./dx(i,J).^2 - 6*a.*x(i+1,J))/2;
+      c = (diff (y(1:nx-2,:), 1) - a.*dx3(:,J) - b.*dx2(:,J))./dx(i,J);
+      d = y(i,:) - ((a.*x(i,J) + b).*x(i,J) + c).*x(i,J);
 
-  elseif (strcmp (method, "spline") || strcmp (method, "*spline"))
-    if (nx == 2 || method(1) == "*") 
+      if (pp)
+        xs = [x(1);x(3:nx-2)];
+        yi = mkpp ([x(1);x(3:nx-2);x(nx)], 
+                   [a(:), (b(:) + 3.*xs(:,J).*a(:)), ... 
+                    (c(:) + 2.*xs(:,J).*b(:) + 3.*xs(:,J)(:).^2.*a(:)), ...
+                    (d(:) + xs(:,J).*c(:) + xs(:,J).^2.*b(:) + ...
+                     xs(:,J).^3.*a(:))], szy(2:end));
+      else
+        yi = ((a(idx,:).*xi(:,J) + b(idx,:)).*xi(:,J) ...
+              + c(idx,:)).*xi(:,J) + d(idx,:);
+      endif
+    endif
+  case {"spline", "*spline"}
+    if (nx == 2 || starmethod) 
       x = linspace(x(1), x(nx), ny); 
     endif
     ## Note that spline's arguments are transposed relative to interp1
@@ -300,13 +276,23 @@
       yi = spline (x.', y.');
       yi.d = szy(2:end);
     else
-      yi(range,:) = spline (x.', y.', xi.').';
+      yi = spline (x.', y.', xi.').';
     endif
-  else
+  otherwise
     error ("interp1: invalid method '%s'", method);
-  endif
+  endswitch
 
   if (! pp)
+    if (! ischar (extrap))
+      ## determine which values are out of range and set them to extrap,
+      ## unless extrap == "extrap".
+      minx = min (x(1), x(nx));
+      maxx = max (x(1), x(nx));
+
+      outliers = xi < minx | ! (xi <= maxx); # this catches even NaNs
+      yi(outliers, :) = extrap;
+    endif
+
     if (! isvector (y) && length (szx) == 2 && (szx(1) == 1 || szx(2) == 1))
       if (szx(1) == 1)
 	yi = reshape (yi, [szx(2), szy(2:end)]);
--- a/scripts/polynomial/pchip.m	Thu Oct 22 08:56:58 2009 +0200
+++ b/scripts/polynomial/pchip.m	Thu Oct 22 10:12:54 2009 +0200
@@ -73,50 +73,49 @@
     print_usage ();
   endif
 
-  x = x(:);
+  x = x(:).';
   n = length (x);
 
   ## Check the size and shape of y
-  ndy = ndims (y);
-  szy = size (y);
-  if (ndy == 2 && (szy(1) == 1 || szy(2) == 1))
-    if (szy(1) == 1)
-      y = y';
-    else
-      szy = fliplr (szy);
-    endif
+  if (isvector (y))
+    y = y(:).';
+    szy = size (y);
   else
-    y = reshape (y, [prod(szy(1:end-1)), szy(end)])';
+    szy = size (y);
+    y = reshape (y, [prod(szy(1:end-1)), szy(end)]);
   endif
 
   h = diff (x);
   if (all (h < 0))
-    x = flipud (x);
+    x = fliplr (x);
     h = diff (x);
-    y = flipud (y);
+    y = fliplr (y);
   elseif (any (h <= 0))
     error("pchip: x must be strictly monotonic");
   endif
 
-  if (rows (y) != n)
+  if (columns (y) != n)
     error("pchip: size of x and y must match");
   endif
 
-  [ry, cy] = size (y);
-  if (cy > 1)
-    h = kron (diff (x), ones (1, cy));
-  endif
-  
-  dy = diff (y) ./ h;
+  f1 = y(:,1:n-1);
+
+  ## Compute derivatives.
+  d = __pchip_deriv__ (x, y, 2);
+  d1 = d(:,1:n-1);
+  d2 = d(:,2:n);
 
-  a = y;
-  b = __pchip_deriv__ (x, y);
-  c = - (b(2:n, :) + 2 * b(1:n - 1, :)) ./ h + 3 * diff (a) ./ h .^ 2;
-  d = (b(1:n - 1, :) + b(2:n, :)) ./ h.^2 - 2 * diff (a) ./ h.^3;
+  ## This is taken from SLATEC. 
+  h = diag (h);
 
-  d = d(1:n - 1, :); c = c(1:n - 1, :);
-  b = b(1:n - 1, :); a = a(1:n - 1, :);
-  coeffs = [d(:), c(:), b(:), a(:)];
+  delta = diff (y, 1, 2) / h;
+  del1 = (d1 - delta) / h;
+  del2 = (d2 - delta) / h;
+  c3 = del1 + del2;
+  c2 = -c3 - del1;
+  c3 = c3 / h;
+
+  coeffs = [c3.'(:), c2.'(:), d1.'(:), f1.'(:)];
   pp = mkpp (x, coeffs, szy(1:end-1));
 
   if (nargin == 2)