diff libinterp/dldfcn/__ode15__.cc @ 29296:7bf91e98bfc6

__ode15__(): Consider the correct number of input arguments in the event callback (bug #59477). * ode15i.m: Provide three input arguments (t, y, yp) in the event callback. * ode15s.m: Provide two input arguments (t, y) in the event callback. * __ode15__.cc: Provide the number of input arguments in the event callback.
author Markus Meisinger <chloros2@gmx.de>
date Mon, 11 Jan 2021 20:11:45 +0100
parents 8f67ad8b3103
children 7854d5752dd2
line wrap: on
line diff
--- a/libinterp/dldfcn/__ode15__.cc	Thu Jan 14 01:15:45 2021 -0500
+++ b/libinterp/dldfcn/__ode15__.cc	Mon Jan 11 20:11:45 2021 +0100
@@ -309,7 +309,8 @@
                  const octave_value& event_fcn, ColumnVector& te,
                  Matrix& ye, ColumnVector& ie, ColumnVector& oldval,
                  ColumnVector& oldisterminal, ColumnVector& olddir,
-                 octave_idx_type& temp, ColumnVector& yold);
+                 octave_idx_type& temp, ColumnVector& yold,
+                 const octave_idx_type num_event_args);
 
     bool
     outputfun (const octave_value& output_fcn, bool haveoutputsel,
@@ -324,7 +325,8 @@
            const ColumnVector& yp, ColumnVector& oldval,
            ColumnVector& oldisterminal, ColumnVector& olddir,
            octave_idx_type cont, octave_idx_type& temp, realtype told,
-           ColumnVector& yold);
+           ColumnVector& yold,
+           const octave_idx_type num_event_args);
 
     void set_maxorder (int maxorder);
 
@@ -334,7 +336,8 @@
                const int refine, bool haverefine, bool haveoutputfcn,
                const octave_value& output_fcn, bool haveoutputsel,
                ColumnVector& outputsel, bool haveeventfunction,
-               const octave_value& event_fcn);
+               const octave_value& event_fcn,
+               const octave_idx_type num_event_args);
 
     void print_stat (void);
 
@@ -591,7 +594,8 @@
                   const int refine, bool haverefine, bool haveoutputfcn,
                   const octave_value& output_fcn, bool haveoutputsel,
                   ColumnVector& outputsel, bool haveeventfunction,
-                  const octave_value& event_fcn)
+                  const octave_value& event_fcn,
+                  const octave_idx_type num_event_args)
   {
     // Set up output
     ColumnVector tout, yout (m_num), ypout (m_num), ysel (outputsel.numel ());
@@ -618,7 +622,7 @@
     if (haveeventfunction)
       status = IDA::event (event_fcn, te, ye, ie, tsol, y,
                            "init", yp, oldval, oldisterminal,
-                           olddir, cont, temp, tsol, yold);
+                           olddir, cont, temp, tsol, yold, num_event_args);
 
     if (numt > 2)
       {
@@ -652,7 +656,8 @@
             if (haveeventfunction)
               status = IDA::event (event_fcn, te, ye, ie, tout(j), yout,
                                    string, ypout, oldval, oldisterminal,
-                                   olddir, j, temp, tout(j-1), yold);
+                                   olddir, j, temp, tout(j-1), yold,
+                                   num_event_args);
 
             // If integration is stopped, return only the reached steps
             if (status == 1)
@@ -689,7 +694,8 @@
                                          output_fcn, outputsel,
                                          haveeventfunction, event_fcn, te,
                                          ye, ie, oldval, oldisterminal,
-                                         olddir, temp, yold);
+                                         olddir, temp, yold,
+                                         num_event_args);
 
             ypout = NVecToCol (yyp, m_num);
             cont += 1;
@@ -708,7 +714,8 @@
             if (haveeventfunction && ! haverefine && tout(cont) < tend)
               status = IDA::event (event_fcn, te, ye, ie, tout(cont), yout,
                                    string, ypout, oldval, oldisterminal,
-                                   olddir, cont, temp, tout(cont-1), yold);
+                                   olddir, cont, temp, tout(cont-1), yold,
+                                   num_event_args);
           }
 
         if (status == 0)
@@ -736,7 +743,7 @@
                   status = IDA::event (event_fcn, te, ye, ie, tend, yout,
                                        string, ypout, oldval, oldisterminal,
                                        olddir, cont, temp, tout(cont-1),
-                                       yold);
+                                       yold, num_event_args);
               }
 
             N_VDestroy_Serial (dky);
@@ -759,11 +766,16 @@
               const ColumnVector& yp, ColumnVector& oldval,
               ColumnVector& oldisterminal, ColumnVector& olddir,
               octave_idx_type cont, octave_idx_type& temp, realtype told,
-              ColumnVector& yold)
+              ColumnVector& yold,
+              const octave_idx_type num_event_args)
   {
     bool status = 0;
 
-    octave_value_list args = ovl (tsol, y, yp);
+    octave_value_list args;
+    if (num_event_args == 2)
+      args = ovl (tsol, y);
+    else
+      args = ovl (tsol, y, yp);
 
     // cont is the number of steps reached by the solver
     // temp is the number of events registered
@@ -867,7 +879,8 @@
                     const octave_value& event_fcn, ColumnVector& te,
                     Matrix& ye, ColumnVector& ie, ColumnVector& oldval,
                     ColumnVector& oldisterminal, ColumnVector& olddir,
-                    octave_idx_type& temp, ColumnVector& yold)
+                    octave_idx_type& temp, ColumnVector& yold,
+                    const octave_idx_type num_event_args)
   {
     realtype h = 0, tcur = 0;
     bool status = 0;
@@ -919,7 +932,7 @@
           status = IDA::event (event_fcn, te, ye, ie, tout(cont),
                                yout, string, ypout, oldval,
                                oldisterminal, olddir, cont, temp,
-                               tout(cont-1), yold);
+                               tout(cont-1), yold, num_event_args);
       }
 
     N_VDestroy_Serial (dky);
@@ -1091,7 +1104,8 @@
             const realtype t0,
             const ColumnVector& y0,
             const ColumnVector& yp0,
-            const octave_scalar_map& options)
+            const octave_scalar_map& options,
+            const octave_idx_type num_event_args)
   {
     octave_value_list retval;
 
@@ -1217,7 +1231,7 @@
     retval = dae.integrate (numt, tspan, y0, yp0, refine,
                             haverefine, haveoutputfunction,
                             output_fcn, haveoutputsel, outputsel,
-                            haveeventfunction, event_fcn);
+                            haveeventfunction, event_fcn, num_event_args);
 
     // Statistics
     bool havestats = options.getfield ("havestats").bool_value ();
@@ -1233,7 +1247,7 @@
 
 DEFUN_DLD (__ode15__, args, ,
            doc: /* -*- texinfo -*-
-@deftypefn {} {@var{t}, @var{y} =} __ode15__ (@var{fun}, @var{tspan}, @var{y0}, @var{yp0}, @var{options})
+@deftypefn {} {@var{t}, @var{y} =} __ode15__ (@var{fun}, @var{tspan}, @var{y0}, @var{yp0}, @var{options}, @var{num_event_args})
 Undocumented internal function.
 @end deftypefn */)
 {
@@ -1241,7 +1255,7 @@
 #if defined (HAVE_SUNDIALS)
 
   // Check number of parameters
-  if (args.length () != 5)
+  if (args.length () != 6)
     print_usage ();
 
   // Check odefun
@@ -1283,9 +1297,15 @@
   octave_scalar_map options
     = args(4).xscalar_map_value ("__ode15__: OPTS argument must be a scalar structure");
 
+  // Provided number of arguments in the ode callback function
+  octave_idx_type num_event_args
+    = args(5).xidx_type_value ("__ode15__: NUM_EVENT_ARGS must be an integer");
 
-  return octave::do_ode15 (ida_fcn, tspan, numt, t0, y0, yp0, options);
+  if (num_event_args != 2 && num_event_args != 3)
+    error ("__ode15__: number of input arguments in event callback must be 2 or 3");
 
+  return octave::do_ode15 (ida_fcn, tspan, numt, t0, y0, yp0, options,
+                           num_event_args);
 
 #else