changeset 237:418a5119047b

New interface for function evaluation
author Eugenio Gianniti <eugenio.gianniti@mail.polimi.it>
date Mon, 23 Jun 2014 20:28:28 +0200
parents a51c09492e30
children b96f6b12f8ca
files src/feval.cc src/function.h
diffstat 2 files changed, 101 insertions(+), 76 deletions(-) [+]
line wrap: on
line diff
--- a/src/feval.cc	Mon Jun 23 17:06:26 2014 +0200
+++ b/src/feval.cc	Mon Jun 23 20:28:28 2014 +0200
@@ -17,89 +17,115 @@
 
 #include "function.h"
 #include <stdexcept>
+#include <vector>
 
-DEFUN_DLD (feval, args, , "-*- texinfo -*-\n\
-@deftypefn {Function File} {[@var{value}]} = \
-feval (@var{function_name}, @var{Coordinate})\n\
+DEFUN_DLD (feval, args, nargout, "-*- texinfo -*-\n\
+@deftypefn {Function File} {[@var{fx}, @var{fy}, @var{fz}]} = \
+feval (@var{function_name}, @var{x}, @var{y}, @var{z})\n\
 Evaluate a function at a specific point of the domain and return the value. \n\
-The input parameters are the function and the point where it has to\
-be evaluated.\n\
-The point can be either a vector or a matrix. If it is a matrix, each column\
-represents a different point at which we evaluates the function.\n\
+The input parameters are the function and the coordinates of the point where \n\
+it has to be evaluated.\n\n\
+Be aware that the number of input arguments and outputs depends on the \n\
+dimensions of its domain and codomain.\n\
+@example\n\
+value = feval (p, x , y)\n\
+@end example\n\
+This is the expected call for a scalar field on a bidimensional domain.\n\
 @seealso{Function}\n\
 @end deftypefn")
 {
 
   int nargin = args.length ();
-  octave_value retval=0;
+  octave_value_list retval;
   
-  if (nargin < 2 || nargin > 2)
-    print_usage ();
-  else
+  if (! function_type_loaded)
     {
-      if (! function_type_loaded)
-        {
-          function::register_type ();
-          function_type_loaded = true;
-          mlock ();
-        }
+      function::register_type ();
+      function_type_loaded = true;
+      mlock ();
+    }
 
-      if (args(0).type_id () == function::static_type_id ())
+  if (args(0).type_id () == function::static_type_id ())
+    {
+      const function & fspo =
+        static_cast<const function&> (args(0).get_rep ());
+
+      if (!error_state)
         {
-          const function & fspo =
-            static_cast<const function&> (args(0).get_rep ());
-          Matrix point= args(1).matrix_value ();
-
-          if (!error_state)
-            {
-              const boost::shared_ptr<const dolfin::Function> 
-                & f = fspo.get_pfun ();
+          const boost::shared_ptr<const dolfin::Function> 
+            & f = fspo.get_pfun ();
 
-              if (point.rows () == 1)
-                point = point.transpose ();
+          octave_idx_type pdim = f->geometric_dimension ();
+          if (nargin != pdim + 1)
+            print_usage ();
+          else
+            {
+              std::vector <Matrix> coordinates;
+              dim_vector dims;
 
-              if (point.rows () != f->geometric_dimension ())
-                {
-                  error ("feval: wrong coordinates dimension");
-                }
-              else
+              for (octave_idx_type in = 1; in <= pdim; ++in)
                 {
-                  dim_vector dims;
-                  dims.resize (2);
-                  dims(0) = f->value_dimension (0);
-                  dims(1) = point.cols ();
-                  Matrix res (dims);
-
-                  dim_vector dims_tmp;
-                  dims_tmp.resize (2);
-                  dims_tmp(0) = f->value_dimension (0);
-                  dims_tmp(1) = 1;
-                  Array<double> res_tmp (dims_tmp);
-
-                  for (uint i = 0; i < point.cols (); ++i)
+                  if (! args(in).is_real_type ())
+                    error ("invalid argument");
+                  else
                     {
-                      Array<double> point_tmp = point.column (i);
-                      dolfin::Array<double> 
-                        x (point_tmp.length (), point_tmp.fortran_vec ());
-                      dolfin::Array<double> 
-                        values (res_tmp.length (), res_tmp.fortran_vec ());
-                      try
+                      Matrix aux = args(in).matrix_value ();
+                      if (in == 1)
+                        dims = aux.dims ();
+                      else
                         {
-                          f->eval (values, x);
+                          dim_vector newdims = aux.dims ();
+                          if (dims != newdims)
+                            {
+                              std::string msg = "all the input matrices should";
+                              msg += " have the same size";
+                              error (msg.c_str ());
+                            }
                         }
-                      catch (std::runtime_error & err)
-                        {
-                          std::string msg = "feval: cannot evaluate a function";
-                          msg += " outside of its domain";
-                          error (msg.c_str ());
-                        }
-                      for (uint j = 0; j < res_tmp.length (); ++j)
-                        res(i, j) = res_tmp (j);
+                      coordinates.push_back (aux);
                     }
-                  retval = octave_value (res);
                 }
+
+              octave_idx_type vdim = f->value_dimension (0);
+              if (nargout != vdim)
+                error ("wrong number of output arguments");
+
+              std::vector <Matrix> evaluations;
+              for (octave_idx_type out = 0; out < vdim; ++out)
+                evaluations.push_back (Matrix (dims));
+
+              for (octave_idx_type k = 0; k < dims.numel (); ++k)
+                {
+                  Array<double> point (dim_vector (pdim, 1));
+                  for (octave_idx_type el = 0; el < pdim; ++el)
+                    point (el) = coordinates[el] (k);
+                  dolfin::Array<double> 
+                    x (point.length (), point.fortran_vec ());
+
+                  Array<double> res (dim_vector (vdim, 1));
+                  dolfin::Array<double> 
+                    values (res.length (), res.fortran_vec ());
+                  try
+                    {
+                      f->eval (values, x);
+                    }
+                  catch (std::runtime_error & err)
+                    {
+                      std::string msg = "cannot evaluate a function";
+                      msg += " outside of its domain";
+                      error (msg.c_str ());
+                    }
+
+                  for (octave_idx_type el = 0; el < vdim; ++el)
+                    evaluations[el] (k) = res (el);
+                }
+
+              for (std::vector<Matrix>::iterator it = evaluations.begin ();
+                   it != evaluations.end (); ++it)
+                retval.append (octave_value (*it));
             }
         }
     }
+
   return retval;
 }
--- a/src/function.h	Mon Jun 23 17:06:26 2014 +0200
+++ b/src/function.h	Mon Jun 23 20:28:28 2014 +0200
@@ -74,15 +74,24 @@
            const std::list<octave_value_list>& idx)
     {
       octave_value retval;
+      retval = subsref (type, idx, 1);
+      return retval;
+    }
+
+  octave_value_list
+  subsref (const std::string& type,
+           const std::list<octave_value_list>& idx,
+           int nargout)
+    {
+      octave_value_list retval;
 
       switch (type[0])
         {
         case '(':
           {
-            std::list<octave_value_list> args;
-            args.push_back (octave_value (new function (*this)));
-            args.push_back (idx.front ());
-            retval = feval ("feval", args);
+            std::list<octave_value_list> args (idx);
+            args.push_front (octave_value (new function (*this)));
+            retval = feval ("feval", args, nargout);
           }
           break;
 
@@ -101,16 +110,6 @@
       return retval;
     }
 
-  octave_value_list
-  subsref (const std::string& type,
-           const std::list<octave_value_list>& idx,
-           int)
-    {
-      octave_value_list retval;
-      retval = subsref (type, idx);
-      return retval;
-    }
-
  private:
 
   std::string str;