diff src/DLD-FUNCTIONS/daspk.cc @ 4140:303b28a7a7e4

[project @ 2002-11-01 02:53:13 by jwe]
author jwe
date Fri, 01 Nov 2002 02:53:14 +0000
parents 19a1626b8d57
children b02ada83de67
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/daspk.cc	Fri Nov 01 00:49:13 2002 +0000
+++ b/src/DLD-FUNCTIONS/daspk.cc	Fri Nov 01 02:53:14 2002 +0000
@@ -46,6 +46,13 @@
 // Global pointer for user defined function required by daspk.
 static octave_function *daspk_fcn;
 
+// Global pointer for optional user defined jacobian function.
+static octave_function *daspk_jac;
+
+// Have we warned about imaginary values returned from user function?
+static bool warned_fcn_imaginary = false;
+static bool warned_jac_imaginary = false;
+
 // Is this a recursive call?
 static int call_depth = 0;
 
@@ -99,6 +106,12 @@
       int tlen = tmp.length ();
       if (tlen > 0 && tmp(0).is_defined ())
 	{
+	  if (! warned_fcn_imaginary && tmp(0).is_complex_type ())
+	    {
+	      warning ("daspk: ignoring imaginary part returned from user-supplied function");
+	      warned_fcn_imaginary = true;
+	    }
+
 	  retval = ColumnVector (tmp(0).vector_value ());
 
 	  if (tlen > 1)
@@ -114,6 +127,76 @@
   return retval;
 }
 
+Matrix
+daspk_user_jacobian (const ColumnVector& x, const ColumnVector& xdot,
+		     double t, double cj)
+{
+  Matrix retval;
+
+  int nstates = x.capacity ();
+
+  assert (nstates == xdot.capacity ());
+
+  octave_value_list args;
+
+  args(3) = cj;
+  args(2) = t;
+
+  if (nstates > 1)
+    {
+      Matrix m1 (nstates, 1);
+      Matrix m2 (nstates, 1);
+      for (int i = 0; i < nstates; i++)
+	{
+	  m1 (i, 0) = x (i);
+	  m2 (i, 0) = xdot (i);
+	}
+      octave_value state (m1);
+      octave_value deriv (m2);
+      args(1) = deriv;
+      args(0) = state;
+    }
+  else
+    {
+      double d1 = x (0);
+      double d2 = xdot (0);
+      octave_value state (d1);
+      octave_value deriv (d2);
+      args(1) = deriv;
+      args(0) = state;
+    }
+
+  if (daspk_jac)
+    {
+      octave_value_list tmp = daspk_jac->do_multi_index_op (1, args);
+
+      if (error_state)
+	{
+	  gripe_user_supplied_eval ("daspk");
+	  return retval;
+	}
+
+      int tlen = tmp.length ();
+      if (tlen > 0 && tmp(0).is_defined ())
+	{
+	  if (! warned_jac_imaginary && tmp(0).is_complex_type ())
+	    {
+	      warning ("daspk: ignoring imaginary part returned from user-supplied jacobian function");
+	      warned_jac_imaginary = true;
+	    }
+
+	  retval = tmp(0).matrix_value ();
+
+	  if (error_state || retval.length () == 0)
+	    gripe_user_supplied_eval ("daspk");
+	}
+      else
+	gripe_user_supplied_eval ("daspk");
+    }
+
+  return retval;
+}
+
 #define DASPK_ABORT() \
   do \
     { \
@@ -235,6 +318,9 @@
 {
   octave_value_list retval;
 
+  warned_fcn_imaginary = false;
+  warned_jac_imaginary = false;
+
   unwind_protect::begin_frame ("Fdaspk");
 
   unwind_protect_int (call_depth);
@@ -247,12 +333,46 @@
 
   if (nargin > 3 && nargin < 6)
     {
-      daspk_fcn = extract_function
-	(args(0), "daspk", "__daspk_fcn__",
-	 "function res = __daspk_fcn__ (x, xdot, t) res = ",
-	 "; endfunction");
+      daspk_fcn = 0;
+      daspk_jac = 0;
+
+      octave_value f_arg = args(0);
+
+      switch (f_arg.rows ())
+	{
+	case 1:
+	  daspk_fcn = extract_function
+	    (args(0), "daspk", "__daspk_fcn__",
+	     "function res = __daspk_fcn__ (x, xdot, t) res = ",
+	     "; endfunction");
+	  break;
+
+	case 2:
+	  {
+	    string_vector tmp = f_arg.all_strings ();
 
-      if (! daspk_fcn)
+	    if (! error_state)
+	      {
+		daspk_fcn = extract_function
+		  (tmp(0), "daspk", "__daspk_fcn__",
+		   "function res = __daspk_fcn__ (x, xdot, t) res = ",
+		   "; endfunction");
+
+		if (daspk_fcn)
+		  {
+		    daspk_jac = extract_function
+		      (tmp(1), "daspk", "__daspk_jac__",
+		       "function jac = __daspk_jac__ (x, xdot, t, cj) jac = ",
+		       "; endfunction");
+
+		    if (! daspk_jac)
+		      daspk_fcn = 0;
+		  }
+	      }
+	  }
+	}
+
+      if (error_state || ! daspk_fcn)
 	DASPK_ABORT ();
 
       ColumnVector state = ColumnVector (args(1).vector_value ());
@@ -288,6 +408,9 @@
       double tzero = out_times (0);
 
       DAEFunc func (daspk_user_function);
+      if (daspk_jac)
+	func.set_jacobian_function (daspk_user_jacobian);
+
       DASPK dae (state, deriv, tzero, func);
       dae.set_options (daspk_opts);