changeset 20635:a22d8a2eb0e5

fix adaptive strategy in ode solvers. * script/ode/ode45.m: remove unused option OutputSave * script/ode/private/integrate_adaptive.m: rewrite algorithm to be more compatible. * script/ode/private/runge_kutta_45_dorpri.m: use kahan summation for time increment.
author Carlo de Falco <carlo.defalco@polimi.it>
date Sun, 11 Oct 2015 18:44:58 +0200
parents 1c5a86b7f838
children 43822bda4f65
files scripts/ode/ode45.m scripts/ode/private/integrate_adaptive.m scripts/ode/private/runge_kutta_45_dorpri.m
diffstat 3 files changed, 181 insertions(+), 218 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/ode/ode45.m	Sat Oct 10 16:52:59 2015 -0700
+++ b/scripts/ode/ode45.m	Sun Oct 11 18:44:58 2015 +0200
@@ -238,10 +238,10 @@
     vodeoptions.vhaveoutputselection = false; 
   endif
 
-  ## Implementation of the option OutputSave has been finished.
-  ## This option can be set by the user to another value than default value.
-  if (isempty (vodeoptions.OutputSave))
-    vodeoptions.OutputSave = 1;
+  ## "OutputSave" option will be ignored.
+  if (! isempty (vodeoptions.OutputSave))
+    warning ("OdePkg:InvalidArgument",
+             "Option 'OutputSave' will be ignored.");
   endif
 
   ## Implementation of the option Refine has been finished.
@@ -406,12 +406,12 @@
   ## Print additional information if option Stats is set
   if (strcmp (vodeoptions.Stats, "on"))
     vhavestats = true;
-    vnsteps    = solution.vcntloop-2;        # vcntloop from 2..end
+    vnsteps    = solution.vcntloop-2;                 # vcntloop from 2..end
     vnfailed   = (solution.vcntcycles-1)-(vnsteps)+1; # vcntcycl from 1..end
-    vnfevals   = 7*(solution.vcntcycles-1);  # number of ode evaluations
-    vndecomps  = 0;                          # number of LU decompositions
-    vnpds      = 0;                          # number of partial derivatives
-    vnlinsols  = 0;                          # no. of linear systems solutions
+    vnfevals   = 7*(solution.vcntcycles-1);           # number of ode evaluations
+    vndecomps  = 0;                                   # number of LU decompositions
+    vnpds      = 0;                                   # number of partial derivatives
+    vnlinsols  = 0;                                   # no. of linear systems solutions
     ## Print cost statistics if no output argument is given
     if (nargout == 0)
       printf ("Number of successful steps: %d\n", vnsteps);
@@ -598,12 +598,6 @@
 %!test  # Details of OutputSel and Refine can't be tested
 %! vopt = odeset ("OutputFcn", @fout, "OutputSel", 1, "Refine", 5);
 %! vsol = ode45 (@fpol, [0 2], [2 0], vopt);
-%!test  # Details of OutputSave can't be tested
-%! vopt = odeset ("OutputSave", 1, "OutputSel", 1);
-%! vsla = ode45 (@fpol, [0 2], [2 0], vopt);
-%! vopt = odeset ("OutputSave", 2);
-%! vslb = ode45 (@fpol, [0 2], [2 0], vopt);
-%! assert (length (vsla.x) + 1 >= 2 * length (vslb.x))
 %!test  # Stats must add further elements in vsol
 %! vopt = odeset ("Stats", "on");
 %! vsol = ode45 (@fpol, [0 2], [2 0], vopt);
--- a/scripts/ode/private/integrate_adaptive.m	Sat Oct 10 16:52:59 2015 -0700
+++ b/scripts/ode/private/integrate_adaptive.m	Sun Oct 11 18:44:58 2015 +0200
@@ -62,35 +62,25 @@
 
 function solution = integrate_adaptive (stepper, order, func, tspan, x0, options)
 
-  solution = struct ();
+  fixed_times = numel (tspan) > 2;
 
-  ## first values for time and solution
-  t = tspan(1);
-  x = x0(:);
+  t_new = t_old = t = tspan(1);
+  x_new = x_old = x = x0(:);
   
   ## get first initial timestep
-  dt = odeget (options, "InitialStep",
-               starting_stepsize (order, func, t, x, options.AbsTol,
-                                  options.RelTol, options.vnormcontrol),
-               "fast_not_empty");
-  vdirection = odeget (options, "vdirection", [], "fast");
-  if (sign (dt) != vdirection)
-    dt = -dt;
-  endif
-  dt = vdirection * min (abs (dt), options.MaxStep);
+  dt = starting_stepsize (order, func, t, x, options.AbsTol,
+                          options.RelTol, options.vnormcontrol);
+  dt = odeget (options, "InitialStep", dt, "fast_not_empty");
+  
+  dir = odeget (options, "vdirection", [], "fast");
+  dt = dir * min (abs (dt), options.MaxStep);
 
-  ## Set parameters
-  k = length (tspan);
-  counter = 2;
-  comp = 0.0;
-  tk = tspan(1);
-  options.comp = comp;
+  options.comp = 0.0;
   
   ## Factor multiplying the stepsize guess
   facmin = 0.8;
+  facmax = 1.5;
   fac = 0.38^(1/(order+1));  # formula taken from Hairer
-  t_caught = false;
-
 
   ## Initialize the OutputFcn
   if (options.vhaveoutputfunction)
@@ -105,230 +95,207 @@
 
   ## Initialize the EventFcn
   if (options.vhaveeventfunction)
-    odepkg_event_handle (options.Events, t(end), x,
+    odepkg_event_handle (options.Events, tspan(end), x,
                          "init", options.vfunarguments{:});
   endif
 
+  if (options.vhavenonnegative)
+    nn = options.NonNegative;
+  endif
+  
   solution.vcntloop = 2;
   solution.vcntcycles = 1;
-  vcntiter = 0;
+  solution.vcntsave = 2;
   solution.vunhandledtermination = true;
-  solution.vcntsave = 2;
-  
-  z = t;
-  u = x;
+  ireject = 0;
 
   k_vals = []; 
-  
-  while (counter <= k)
-    facmax = 1.5;
+  iout = istep = 1;
+  while (dir * t_old < dir * tspan(end))
 
-    ## Compute integration step from t to t+dt
-    if (isempty (k_vals))
-      [s, y, y_est, new_k_vals] = stepper (func, z(end), u(:,end),
-                                           dt, options);
-    else
-      [s, y, y_est, new_k_vals] = stepper (func, z(end), u(:,end),
-                                           dt, options, k_vals);
-    endif
-    
+    ## Compute integration step from t_old to t_new = t_old + dt
+    [t_new, options.comp] = kahan (t_old, options.comp, dt);
+    [t_new, x_new, x_est, new_k_vals] = ...
+    stepper (func, t_old, x_old, dt, options, k_vals, t_new);
+
+    solution.vcntcycles++;
+
     if (options.vhavenonnegative)
-      x(options.NonNegative,end) = abs (x(options.NonNegative,end));
-      y(options.NonNegative,end) = abs (y(options.NonNegative,end));
-      y_est(options.NonNegative,end) = abs (y_est(options.NonNegative,end));
-    endif
-
-    if (options.vhaveoutputfunction && options.vhaverefine)
-      vSaveVUForRefine = u(:,end);
+      x(nn,end) = abs (x(nn,end));
+      y(nn,end) = abs (y(nn,end));
+      y_est(nn,end) = abs (y_est(nn,end));
     endif
 
-    err = AbsRel_Norm (y(:,end), u(:,end), options.AbsTol, options.RelTol,
-                       options.vnormcontrol, y_est(:,end));
+    err = AbsRel_Norm (x_new, x_old, options.AbsTol, options.RelTol,
+                       options.vnormcontrol, x_est);
     
-    ## Solution accepted only if the error is less or equal to 1.0
+    ## Accepted solution only if err <= 1.0
     if (err <= 1)
-      
-      [tk, options.comp] = kahan (tk, options.comp, dt);
-      s(end) = tk;
 
-      ## values on this interval for time and solution
-      z = [z(end); s];
-      u = [u(:,end), y];
-      k_vals = new_k_vals;
-      
-      ## if next tspan value is caught, update counter
-      if ((z(end) == tspan(counter))
-          || (abs (z(end) - tspan(counter)) /
-              (max (abs (z(end)), abs (tspan(counter)))) < 8*eps) )
-        counter++;
-        
-        ## if there is an element in time vector at which the solution is
-        ## required the program must compute this solution before going on with
-        ## next steps
-      elseif (vdirection * z(end) > vdirection * tspan(counter))
+      solution.vcntloop++;
+      ireject = 0;
+            
+      ## if output time steps are fixed
+      if (fixed_times)
 
-        ## initialize counter for the following cycle
-        i = 2;
-        while (i <= length (z))
+        t_caught = find ((tspan(iout:end) > t_old)
+                         & (tspan(iout:end) <= t_new));
+        if (! isempty (t_caught))
+          t(t_caught) = tspan(t_caught);
+          iout = max (t_caught);
+          x(:, t_caught) = interpolate ([t_old, t_new], [x_old, x_new],
+                                        t(t_caught));
+
+          istep++;
 
-          ## if next tspan value is caught, update counter
-          if ((counter <= k)
-              && ((z(i) == tspan(counter))
-                  || (abs (z(i) - tspan(counter)) /
-                      (max (abs (z(i)), abs (tspan(counter)))) < 8*eps)) )
-            counter++;
-          endif
-          ## else, loop until there are requested values inside this subinterval
-          while ((counter <= k)
-                 && (vdirection * z(i) > vdirection * tspan(counter)))
-            ## choose interpolation scheme according to order of the solver
-
-            u_interp = ...
-            ode_rk_interpolate (order, [z(i-1) z(i)], [u(:,i-1) u(:,i)],
-                                tspan(counter), k_vals, dt,
-                                options.vfunarguments{:});
-            
-
-            ## add the interpolated value of the solution
-            u = [u(:,1:i-1), u_interp, u(:,i:end)];
-            
-            ## add the time requested
-            z = [z(1:i-1); tspan(counter); z(i:end)];
-
-            ## update counters
-            counter++;
-            i++;
-          endwhile
-
-          ## if new time requested is not out of this interval
-          if ((counter <= k)
-              && (vdirection * z(end) > vdirection * tspan(counter)))
-            ## update the counter
-            i++;
-          else
-            ## stop the cycle and go on with the next iteration
-            i = length (z) + 1;
+          if (options.vhaveeventfunction)
+            ## Call event on each dense output timestep.
+            ##  Stop integration if veventbreak is true
+            break_loop = false;
+            for idenseout = 1:numel (t_caught)
+              id = t_caught(idenseout);
+              td = t(id);
+              solution.vevent = ...
+              odepkg_event_handle (options.Events, t(id), x(:, id), [],
+                                   options.vfunarguments{:});
+              if (! isempty (solution.vevent{1})
+                  && solution.vevent{1} == 1)
+                t(id) = solution.vevent{3}(end);
+                t = t(1:id);
+                x(:, id) = solution.vevent{4}(end, :).';
+                x = x(:,1:id);
+                solution.vunhandledtermination = false; 
+                break_loop = true;
+                break;
+              endif
+            endfor
+            if (break_loop)
+              break;
+            endif
           endif
 
-        endwhile
-      endif
+          ## Call plot.  Stop integration if plot function
+          ## returns true.
+          if (options.vhaveoutputfunction)
+            vcnt = options.Refine + 1;
+            vapproxtime = linspace (t_old, t_new, vcnt);
+            vapproxvals = interp1 ([t_old, t(t_caught), t_new],
+                                   [x_old, x(:, t_caught), x_new],
+                                   vapproxtime, 'linear');
+            if (options.vhaveoutputselection)
+              vapproxvals = vapproxvals(options.OutputSel);
+            endif
+            vpltret = feval (options.OutputFcn, vapproxtime,
+                             vapproxvals, [], options.vfunarguments{:});
+            if (vpltret)  # Leave main loop
+              solution.vunhandledtermination = false;
+              break;
+            endif
+          endif
+
+        endif
+        
+      else
 
-      if (mod (solution.vcntloop-1, options.OutputSave) == 0)
-        x = [x,u(:,2:end)];
-        t = [t;z(2:end)];
-        solution.vcntsave = solution.vcntsave + 1;    
-      endif
-      solution.vcntloop = solution.vcntloop + 1;
-      vcntiter = 0;
-      
-      ## Call plot only if a valid result has been found, therefore this
-      ## code fragment has moved here.  Stop integration if plot function
-      ## returns false
-      if (options.vhaveoutputfunction)
-        for vcnt = 0:options.Refine  # Approximation between told and t
-          if (options.vhaverefine)   # Do interpolation
-            vapproxtime = (vcnt + 1) / (options.Refine + 2);
-            vapproxvals = (1 - vapproxtime) * vSaveVUForRefine ...
-                          + (vapproxtime) * y(:,end);
-            vapproxtime = s(end) + vapproxtime * dt;
-          else
-            vapproxvals = x(:,end);
-            vapproxtime = t(end);
-          endif
+        t(++istep)  = t_new;
+        x(:, istep) = x_new;
+        iout = istep;
+
+        ## Call event handler on new timestep.
+        ##  Stop integration if veventbreak is true
+        if (options.vhaveeventfunction)
+          solution.vevent = ...
+          odepkg_event_handle (options.Events, t(istep), x(:, istep), [],
+                                   options.vfunarguments{:});
+              if (! isempty (solution.vevent{1})
+                  && solution.vevent{1} == 1)
+                t(istep) = solution.vevent{3}(end);
+                x(:, istep) = solution.vevent{4}(end, :).';
+                solution.vunhandledtermination = false; 
+                break;
+              endif
+        endif
+
+        ## Call plot.  Stop integration if plot function
+        ## returns true.
+        if (options.vhaveoutputfunction)
+          vcnt = options.Refine + 1;
+          vapproxtime = linspace (t_old, t_new, vcnt);
+          vapproxvals = interp1 ([t_old, t_new],
+                                 [x_old, x_new],
+                                 vapproxtime, 'linear');
           if (options.vhaveoutputselection)
             vapproxvals = vapproxvals(options.OutputSel);
           endif
           vpltret = feval (options.OutputFcn, vapproxtime,
                            vapproxvals, [], options.vfunarguments{:});
-          if (vpltret)  # Leave refinement loop
+          if (vpltret)  # Leave main loop
+            solution.vunhandledtermination = false;
             break;
           endif
-        endfor
-        if (vpltret)  # Leave main loop
-          solution.vunhandledtermination = false;
-          break;
         endif
+
       endif
-      
-      ## Call event only if a valid result has been found, therefore this
-      ## code fragment has moved here.  Stop integration if veventbreak is
-      ## true
-      if (options.vhaveeventfunction)
-        solution.vevent = odepkg_event_handle (options.Events, t(end),
-                                               x(:,end), [],
-                                               options.vfunarguments{:});
-        if (! isempty (solution.vevent{1})
-            && solution.vevent{1} == 1)
-          t(solution.vcntloop-1,:) = solution.vevent{3}(end,:);
-          x(:,solution.vcntloop-1) = solution.vevent{4}(end,:)';
-          solution.vunhandledtermination = false; 
-          break;
-        endif
+
+      ## move to next time-step
+      t_old = t_new;
+      x_old = x_new;
+      k_vals = new_k_vals;
+
+      solution.vcntloop = solution.vcntloop + 1;
+      vcntiter = 0;
+            
+    else
+
+      ireject++;
+
+      ## Stop solving because in the last 5000 steps no successful valid
+      ## value has been found
+      if (ireject >= 5000)
+        error (["integrate_adaptive: Solving has not been successful. ",
+                " The iterative integration loop exited at time",
+                " t = %f before endpoint at tend = %f was reached. ",
+                " This happened because the iterative integration loop",
+                " does not find a valid solution at this time",
+                " stamp.  Try to reduce the value of 'InitialStep' and/or",
+                " 'MaxStep' with the command 'odeset'.\n"],
+               t_old, tspan(end));
       endif
-      
-    else
-      
-      facmax = 1.0;
-      
+
     endif
     
     ## Compute next timestep, formula taken from Hairer
-    err += eps;    # adding an eps to avoid divisions by zero
-    dt = dt * min (facmax,
-                   max (facmin, fac * (1 / err)^(1 / (order + 1))));
-    dt = vdirection * min (abs (dt), options.MaxStep);
-    
-    ## Update counters that count the number of iteration cycles
-    solution.vcntcycles += 1;  # Needed for cost statistics
-    vcntiter += 1;  # Needed to find iteration problems
+    err += eps;    # avoid divisions by zero
+    dt *= min (facmax, max (facmin, fac * (1 / err)^(1 / (order + 1))));
+    dt = dir * min (abs (dt), options.MaxStep);    
 
-    ## Stop solving because in the last 1000 steps no successful valid
-    ## value has been found
-    if (vcntiter >= 5000)
-      error (["Solving has not been successful.  The iterative",
-              " integration loop exited at time t = %f before endpoint at",
-              " tend = %f was reached.  This happened because the iterative",
-              " integration loop does not find a valid solution at this time",
-              " stamp.  Try to reduce the value of 'InitialStep' and/or",
-              " 'MaxStep' with the command 'odeset'.\n"],
-             s(end), tspan(end));
-    endif
-
-    ## if this is the last iteration, save the length of last interval
-    if (counter > k)
-      j = length (z);
-    endif
   endwhile
   
   ## Check if integration of the ode has been successful
-  if (vdirection * z(end) < vdirection * tspan(end))
+  if (dir * t(end) < dir * tspan(end))
     if (solution.vunhandledtermination == true)
-      error ("OdePkg:InvalidArgument",
+      error ("integrate_adaptive: InvalidArgument",
              ["Solving has not been successful.  The iterative",
               " integration loop exited at time t = %f",
               " before endpoint at tend = %f was reached.  This may",
-              " happen if the stepsize grows smaller than defined in",
-              " vminstepsize.  Try to reduce the value of 'InitialStep'",
+              " happen if the stepsize grows too small. ",
+              " Try to reduce the value of 'InitialStep'",
               " and/or 'MaxStep' with the command 'odeset'.\n"],
-             z(end), tspan(end));
+             t(end), tspan(end));
     else
-      warning ("OdePkg:InvalidArgument",
+      warning ("integrate_adaptive: InvalidArgument",
                ["Solver has been stopped by a call of 'break' in the main",
                 " iteration loop at time t = %f before endpoint at tend = %f ",
                 " was reached.  This may happen because the @odeplot function",
                 " returned 'true' or the @event function returned",
                 " 'true'.\n"],
-               z(end), tspan(end));
+               t(end), tspan(end));
     endif
   endif
 
-  ## Compute how many values are out of time inerval
-  d = vdirection * t((end-(j-1)):end) > vdirection * tspan(end)*ones (j, 1);
-  f = sum (d);
-
   ## Remove not-requested values of time and solution
-  solution.t = t(1:end-f);
-  solution.x = x(:,1:end-f)';
+  solution.t = t;
+  solution.x = x.';
   
 endfunction
 
--- a/scripts/ode/private/runge_kutta_45_dorpri.m	Sat Oct 10 16:52:59 2015 -0700
+++ b/scripts/ode/private/runge_kutta_45_dorpri.m	Sun Oct 11 18:44:58 2015 +0200
@@ -58,7 +58,8 @@
 ## @seealso{odepkg}
 
 function [t_out, x_out, x_est, k] = ...
-         runge_kutta_45_dorpri (f, t, x, dt, varargin)
+         runge_kutta_45_dorpri (f, t, x, dt, opts = [], k_vals = [],
+                                t_out = t + dt)
 
   persistent a = [0           0          0           0        0          0;
                   1/5         0          0           0        0          0;
@@ -80,17 +81,18 @@
   aa = dt * a;
   k = zeros (rows (x), 7);
 
-  if (nargin >= 5)  # options are passed
-    args = varargin{1}.vfunarguments;
-    if (nargin >= 6)  # both the options and the k values are passed
-      k(:,1) = varargin{2}(:,end);  # FSAL property
-    else      
-      k(:,1) = feval (f, t, x, args{:});
-    endif
+  if (! isempty (opts)) # extra arguments for function evaluator
+    args = opts.vfunarguments;
   else
     args = {};
   endif
 
+  if (! isempty (k_vals))  # k values from previous step are passed
+    k(:,1) = k_vals(:,end);  # FSAL property
+  else      
+    k(:,1) = feval (f, t, x, args{:});
+  endif
+    
   k(:,2) = feval (f, s(2), x + k(:,1) * aa(2, 1).', args{:});
   k(:,3) = feval (f, s(3), x + k(:,1:2) * aa(3, 1:2).', args{:});
   k(:,4) = feval (f, s(4), x + k(:,1:3) * aa(4, 1:3).', args{:});
@@ -98,13 +100,13 @@
   k(:,6) = feval (f, s(6), x + k(:,1:5) * aa(6, 1:5).', args{:});
 
   ## compute new time and new values for the unknowns
-  t_out = t + dt;
+  ## t_out = t + dt;
   x_out = x + k(:,1:6) * cc(:);  # 5th order approximation
 
   ## if the estimation of the error is required
   if (nargout >= 3)
     ## new solution to be compared with the previous one
-    k(:,7) = feval (f, t + dt, x_out, args{:});
+    k(:,7) = feval (f, t_out, x_out, args{:});
     cc_prime = dt * c_prime;
     x_est = x + k * cc_prime(:);
   endif