diff src/DLD-FUNCTIONS/dassl.cc @ 5729:e065f7c18bdc

[project @ 2006-04-03 19:03:30 by jwe]
author jwe
date Mon, 03 Apr 2006 19:03:31 +0000
parents 95d90f781ca8
children 080c08b192d8
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/dassl.cc	Mon Apr 03 18:57:51 2006 +0000
+++ b/src/DLD-FUNCTIONS/dassl.cc	Mon Apr 03 19:03:31 2006 +0000
@@ -37,6 +37,7 @@
 #include "gripes.h"
 #include "oct-obj.h"
 #include "ov-fcn.h"
+#include "ov-cell.h"
 #include "pager.h"
 #include "unwind-prot.h"
 #include "utils.h"
@@ -208,9 +209,10 @@
 row of the output @var{x} is @var{x_0} and the first row\n\
 of the output @var{xdot} is @var{xdot_0}.\n\
 \n\
-The first argument, @var{fcn}, is a string that names the function to\n\
-call to compute the vector of residuals for the set of equations.\n\
-It must have the form\n\
+The first argument, @var{fcn}, is a string or a two element cell array\n\
+of strings, inline or function handle, that names the function, to call\n\
+to compute the vector of residuals for the set of equations. It must\n\
+have the form\n\
 \n\
 @example\n\
 @var{res} = f (@var{x}, @var{xdot}, @var{t})\n\
@@ -291,43 +293,112 @@
 
   if (nargin > 3 && nargin < 6 && nargout < 5)
     {
+      std::string fcn_name, fname, jac_name, jname;
       dassl_fcn = 0;
       dassl_jac = 0;
 
       octave_value f_arg = args(0);
 
-      switch (f_arg.rows ())
-	{
-	case 1:
-	  dassl_fcn = extract_function
-	    (f_arg, "dassl", "__dassl_fcn__",
-	     "function res = __dassl_fcn__ (x, xdot, t) res = ",
-	     "; endfunction");
-	  break;
+      if (f_arg.is_cell ())
+  	{
+	  Cell c = f_arg.cell_value ();
+	  if (c.length() == 1)
+	    f_arg = c(0);
+	  else if (c.length() == 2)
+	    {
+	      if (c(0).is_function_handle () || c(0).is_inline_function ())
+		dassl_fcn = c(0).function_value ();
+	      else
+		{
+		  fcn_name = unique_symbol_name ("__dassl_fcn__");
+		  fname = "function y = ";
+		  fname.append (fcn_name);
+		  fname.append (" (x, xdot, t) y = ");
+		  dassl_fcn = extract_function
+		    (c(0), "dassl", fcn_name, fname, "; endfunction");
+		}
+	      
+	      if (dassl_fcn)
+		{
+		  if (c(1).is_function_handle () || c(1).is_inline_function ())
+		    dassl_jac = c(1).function_value ();
+		  else
+		    {
+			jac_name = unique_symbol_name ("__dassl_jac__");
+			jname = "function jac = ";
+			jname.append(jac_name);
+			jname.append (" (x, xdot, t, cj) jac = ");
+			dassl_jac = extract_function
+			  (c(1), "dassl", jac_name, jname, "; endfunction");
 
-	case 2:
-	  {
-	    string_vector tmp = f_arg.all_strings ();
+			if (!dassl_jac)
+			  {
+			    if (fcn_name.length())
+			      clear_function (fcn_name);
+			    dassl_fcn = 0;
+			  }
+		    }
+		}
+	    }
+	  else
+	    DASSL_ABORT1 ("incorrect number of elements in cell array");
+	}
 
-	    if (! error_state)
-	      {
-		dassl_fcn = extract_function
-		  (tmp(0), "dassl", "__dassl_fcn__",
-		   "function res = __dassl_fcn__ (x, xdot, t) res = ",
-		   "; endfunction");
+      if (!dassl_fcn && ! f_arg.is_cell())
+	{
+	  if (f_arg.is_function_handle () || f_arg.is_inline_function ())
+	    dassl_fcn = f_arg.function_value ();
+	  else
+	    {
+	      switch (f_arg.rows ())
+		{
+		case 1:
+		  do
+		    {
+		      fcn_name = unique_symbol_name ("__dassl_fcn__");
+		      fname = "function y = ";
+		      fname.append (fcn_name);
+		      fname.append (" (x, xdot, t) y = ");
+		      dassl_fcn = extract_function
+			(f_arg, "dassl", fcn_name, fname, "; endfunction");
+		    }
+		  while (0);
+		  break;
 
-		if (dassl_fcn)
+		case 2:
 		  {
-		    dassl_jac = extract_function
-		      (tmp(1), "dassl", "__dassl_jac__",
-		       "function jac = __dassl_jac__ (x, xdot, t, cj) jac = ",
-		       "; endfunction");
+		    string_vector tmp = f_arg.all_strings ();
+
+		    if (! error_state)
+		      {
+			fcn_name = unique_symbol_name ("__dassl_fcn__");
+			fname = "function y = ";
+			fname.append (fcn_name);
+			fname.append (" (x, xdot, t) y = ");
+			dassl_fcn = extract_function
+			  (tmp(0), "dassl", fcn_name, fname, "; endfunction");
 
-		    if (! dassl_jac)
-		      dassl_fcn = 0;
+			if (dassl_fcn)
+			  {
+			    jac_name = unique_symbol_name ("__dassl_jac__");
+			    jname = "function jac = ";
+			    jname.append(jac_name);
+			    jname.append (" (x, xdot, t, cj) jac = ");
+			    dassl_jac = extract_function
+			      (tmp(1), "dassl", jac_name, jname, 
+			       "; endfunction");
+
+			    if (!dassl_jac)
+			      {
+				if (fcn_name.length())
+				  clear_function (fcn_name);
+				dassl_fcn = 0;
+			      }
+			  }
+		      }
 		  }
-	      }
-	  }
+		}
+	    }
 	}
 
       if (error_state || ! dassl_fcn)
@@ -381,6 +452,11 @@
       else
 	output = dae.integrate (out_times, deriv_output);
 
+      if (fcn_name.length())
+	clear_function (fcn_name);
+      if (jac_name.length())
+	clear_function (jac_name);
+
       if (! error_state)
 	{
 	  std::string msg = dae.error_message ();