diff src/DLD-FUNCTIONS/cellfun.cc @ 9450:cf714e75c656

implement overloaded function handles
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 23 Jul 2009 14:44:30 +0200
parents 610bf90fce2a
children cb0b21f34abc
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/cellfun.cc	Wed Jul 22 15:11:04 2009 +0200
+++ b/src/DLD-FUNCTIONS/cellfun.cc	Thu Jul 23 14:44:30 2009 +0200
@@ -267,8 +267,6 @@
 @end deftypefn")
 {
   octave_value_list retval;
-  std::string name = "function";
-  octave_function *func = 0;
   int nargin = args.length ();
   nargout = (nargout < 1 ? 1 : nargout);
 
@@ -279,20 +277,7 @@
       return retval;
     }
 
-  if (args(0).is_function_handle () || args(0).is_inline_function ())
-    {
-      func = args(0).function_value ();
-
-      if (error_state)
-	return retval;
-    }
-  else if (args(0).is_string ())
-    name = args(0).string_value ();
-  else
-    {
-      error ("cellfun: first argument must be a string or function handle");
-      return retval;
-    }	
+  octave_value func = args(0);
 
   if (! args(1).is_cell ())
     {
@@ -305,334 +290,327 @@
   
   octave_idx_type k = f_args.numel ();
 
-  if (name == "isempty")
-    {      
-      boolNDArray result (f_args.dims ());
-      for (octave_idx_type count = 0; count < k ; count++)
-        result(count) = f_args.elem(count).is_empty ();
-      retval(0) = result;
-    }
-  else if (name == "islogical")
-    {
-      boolNDArray result (f_args.dims ());
-      for (octave_idx_type  count= 0; count < k ; count++)
-        result(count) = f_args.elem(count).is_bool_type ();
-      retval(0) = result;
-    }
-  else if (name == "isreal")
-    {
-      boolNDArray result (f_args.dims ());
-      for (octave_idx_type  count= 0; count < k ; count++)
-        result(count) = f_args.elem(count).is_real_type ();
-      retval(0) = result;
-    }
-  else if (name == "length")
-    {
-      NDArray result (f_args.dims ());
-      for (octave_idx_type  count= 0; count < k ; count++)
-        result(count) = static_cast<double> (f_args.elem(count).length ());
-      retval(0) = result;
-    }
-  else if (name == "ndims")
+  if (func.is_string ())
     {
-      NDArray result (f_args.dims ());
-      for (octave_idx_type count = 0; count < k ; count++)
-        result(count) = static_cast<double> (f_args.elem(count).ndims ());
-      retval(0) = result;
-    }
-  else if (name == "prodofsize" || name == "numel")
-    {
-      NDArray result (f_args.dims ());
-      for (octave_idx_type count = 0; count < k ; count++)
-        result(count) = static_cast<double> (f_args.elem(count).numel ());
-      retval(0) = result;
-    }
-  else if (name == "size")
-    {
-      if (nargin == 3)
+      std::string name = func.string_value ();
+      if (name.find_first_of ("(x)") != std::string::npos)
+        warning ("cellfun: passing function body as string is no longer supported."
+                 " Use @ or `inline'.");
+
+      if (name == "isempty")
+        {      
+          boolNDArray result (f_args.dims ());
+          for (octave_idx_type count = 0; count < k ; count++)
+            result(count) = f_args.elem(count).is_empty ();
+          retval(0) = result;
+        }
+      else if (name == "islogical")
+        {
+          boolNDArray result (f_args.dims ());
+          for (octave_idx_type  count= 0; count < k ; count++)
+            result(count) = f_args.elem(count).is_bool_type ();
+          retval(0) = result;
+        }
+      else if (name == "isreal")
+        {
+          boolNDArray result (f_args.dims ());
+          for (octave_idx_type  count= 0; count < k ; count++)
+            result(count) = f_args.elem(count).is_real_type ();
+          retval(0) = result;
+        }
+      else if (name == "length")
+        {
+          NDArray result (f_args.dims ());
+          for (octave_idx_type  count= 0; count < k ; count++)
+            result(count) = static_cast<double> (f_args.elem(count).length ());
+          retval(0) = result;
+        }
+      else if (name == "ndims")
         {
-          int d = args(2).nint_value () - 1;
-
-          if (d < 0)
-	    error ("cellfun: third argument must be a positive integer");
-
-	  if (! error_state)
+          NDArray result (f_args.dims ());
+          for (octave_idx_type count = 0; count < k ; count++)
+            result(count) = static_cast<double> (f_args.elem(count).ndims ());
+          retval(0) = result;
+        }
+      else if (name == "prodofsize" || name == "numel")
+        {
+          NDArray result (f_args.dims ());
+          for (octave_idx_type count = 0; count < k ; count++)
+            result(count) = static_cast<double> (f_args.elem(count).numel ());
+          retval(0) = result;
+        }
+      else if (name == "size")
+        {
+          if (nargin == 3)
             {
-              NDArray result (f_args.dims ());
-              for (octave_idx_type count = 0; count < k ; count++)
+              int d = args(2).nint_value () - 1;
+
+              if (d < 0)
+                error ("cellfun: third argument must be a positive integer");
+
+              if (! error_state)
                 {
-                  dim_vector dv = f_args.elem(count).dims ();
-                  if (d < dv.length ())
-	            result(count) = static_cast<double> (dv(d));
-                  else
-	            result(count) = 1.0;
+                  NDArray result (f_args.dims ());
+                  for (octave_idx_type count = 0; count < k ; count++)
+                    {
+                      dim_vector dv = f_args.elem(count).dims ();
+                      if (d < dv.length ())
+                        result(count) = static_cast<double> (dv(d));
+                      else
+                        result(count) = 1.0;
+                    }
+                  retval(0) = result;
                 }
+            }
+          else
+            error ("not enough arguments for `size'");
+        }
+      else if (name == "isclass")
+        {
+          if (nargin == 3)
+            {
+              std::string class_name = args(2).string_value();
+              boolNDArray result (f_args.dims ());
+              for (octave_idx_type count = 0; count < k ; count++)
+                result(count) = (f_args.elem(count).class_name() == class_name);
+
               retval(0) = result;
             }
+          else
+            error ("not enough arguments for `isclass'");
         }
       else
-        error ("not enough arguments for `size'");
+        {
+          func = symbol_table::find_function (name);
+          if (func.is_undefined ())
+            error ("cellfun: invalid function name: %s", name.c_str ());
+        }
     }
-  else if (name == "isclass")
-    {
-      if (nargin == 3)
-        {
-          std::string class_name = args(2).string_value();
-          boolNDArray result (f_args.dims ());
-          for (octave_idx_type count = 0; count < k ; count++)
-            result(count) = (f_args.elem(count).class_name() == class_name);
-          
-          retval(0) = result;
-        }
-      else
-        error ("not enough arguments for `isclass'");
-    }
-  else 
+
+  if (error_state || ! retval.empty ())
+    return retval;
+
+  if (func.is_function_handle () || func.is_inline_function ()
+      || func.is_function ())
     {
       unwind_protect::frame_id_t uwp_frame = unwind_protect::begin_frame ();
       unwind_protect::protect_var (buffer_error_messages);
 
-      std::string fcn_name;
-      
-      if (! func)
-	{
-	  fcn_name = unique_symbol_name ("__cellfun_fcn_");
-	  std::string fname = "function y = ";
-	  fname.append (fcn_name);
-	  fname.append ("(x) y = ");
-	  func = extract_function (args(0), "cellfun", fcn_name, fname,
-				       "; endfunction");
-	}
+      octave_value_list inputlist;
+      bool uniform_output = true;
+      octave_value error_handler;
+      int offset = 1;
+      int i = 1;
+      OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
+      // This is to prevent copy-on-write.
+      const Cell *cinputs = inputs;
 
-      if (! func)
-	error ("unknown function");
-      else
-	{
-	  octave_value_list inputlist;
-	  bool uniform_output = true;
-	  bool have_error_handler = false;
-	  std::string err_name;
-	  octave_function *error_handler = 0;
-	  int offset = 1;
-	  int i = 1;
-	  OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
-          // This is to prevent copy-on-write.
-          const Cell *cinputs = inputs;
+      while (i < nargin)
+        {
+          if (args(i).is_string())
+            {
+              std::string arg = args(i++).string_value();
+              if (i == nargin)
+                {
+                  error ("cellfun: parameter value is missing");
+                  goto cellfun_err;
+                }
 
-	  while (i < nargin)
-	    {
-	      if (args(i).is_string())
-		{
-		  std::string arg = args(i++).string_value();
-		  if (i == nargin)
-		    {
-		      error ("cellfun: parameter value is missing");
-		      goto cellfun_err;
-		    }
-
-		  std::transform (arg.begin (), arg.end (), 
-				  arg.begin (), tolower);
-
-		  if (arg == "uniformoutput")
-		    uniform_output = args(i++).bool_value();
-		  else if (arg == "errorhandler")
-		    {
-		      if (args(i).is_function_handle () || 
-			  args(i).is_inline_function ())
-			{
-			  error_handler = args(i).function_value ();
+              std::transform (arg.begin (), arg.end (), 
+                              arg.begin (), tolower);
 
-			  if (error_state)
-			    goto cellfun_err;
-			}
-		      else if (args(i).is_string ())
-			{
-			  err_name = unique_symbol_name ("__cellfun_fcn_");
-			  std::string fname = "function y = ";
-			  fname.append (fcn_name);
-			  fname.append ("(x) y = ");
-			  error_handler = extract_function (args(i), "cellfun", 
-							    err_name, fname,
-							    "; endfunction");
-			}
-
-		      if (! error_handler)
-			goto cellfun_err;
+              if (arg == "uniformoutput")
+                uniform_output = args(i++).bool_value();
+              else if (arg == "errorhandler")
+                {
+                  if (args(i).is_function_handle () || 
+                      args(i).is_inline_function ())
+                    {
+                      error_handler = args(i++);
+                    }
+                  else if (args(i).is_string ())
+                    {
+                      std::string err_name = args(i++).string_value ();
+                      error_handler = symbol_table::find_function (err_name);
+                      if (error_handler.is_undefined ())
+                        {
+                          error ("cellfun: invalid function name: %s", err_name.c_str ());
+                          goto cellfun_err;
+                        }
+                    }
+                  else
+                    {
+                      error ("invalid errorhandler value");
+                      goto cellfun_err;
+                    }
+                }
+              else
+                {
+                  error ("cellfun: unrecognized parameter %s", 
+                         arg.c_str());
+                  goto cellfun_err;
+                }
 
-		      have_error_handler = true;
-		      i++;
-		    }
-		  else
-		    {
-		      error ("cellfun: unrecognized parameter %s", 
-			     arg.c_str());
-		      goto cellfun_err;
-		    }
-		  offset += 2;
-		}
-	      else
-		{
-		  inputs[i-offset] = args(i).cell_value ();
-		  if (f_args.dims() != inputs[i-offset].dims())
-		    {
-		      error ("cellfun: Dimension mismatch");
-		      goto cellfun_err;
+              offset += 2;
+            }
+          else
+            {
+              inputs[i-offset] = args(i).cell_value ();
+              if (f_args.dims() != inputs[i-offset].dims())
+                {
+                  error ("cellfun: Dimension mismatch");
+                  goto cellfun_err;
+
+                }
+              i++;
+            }
+        }
 
-		    }
-		  i++;
-		}
-	    }
+      nargin -= offset;
+      inputlist.resize(nargin);
 
-          nargin -= offset;
-	  inputlist.resize(nargin);
-
-	  if (have_error_handler)
-	    buffer_error_messages++;
+      if (error_handler.is_defined ())
+        buffer_error_messages++;
 
-	  if (uniform_output)
-	    {
-              OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout);
+      if (uniform_output)
+        {
+          OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout);
 
-	      for (octave_idx_type count = 0; count < k ; count++)
-		{
-		  for (int j = 0; j < nargin; j++)
-		    inputlist(j) = cinputs[j](count);
+          for (octave_idx_type count = 0; count < k ; count++)
+            {
+              for (int j = 0; j < nargin; j++)
+                inputlist(j) = cinputs[j](count);
 
-		  octave_value_list tmp = feval (func, inputlist, nargout);
+              octave_value_list tmp = func.do_multi_index_op (nargout, inputlist);
 
-		  if (error_state && have_error_handler)
-		    {
-		      Octave_map msg;
-		      msg.assign ("identifier", last_error_id ());
-		      msg.assign ("message", last_error_message ());
-		      msg.assign ("index", octave_value(double (count + static_cast<octave_idx_type>(1))));
-		      octave_value_list errlist = inputlist;
-		      errlist.prepend (msg);
-		      buffer_error_messages--;
-		      error_state = 0;
-		      tmp = feval (error_handler, errlist, nargout);
-		      buffer_error_messages++;
+              if (error_state && error_handler.is_defined ())
+                {
+                  Octave_map msg;
+                  msg.assign ("identifier", last_error_id ());
+                  msg.assign ("message", last_error_message ());
+                  msg.assign ("index", octave_value(double (count + static_cast<octave_idx_type>(1))));
+                  octave_value_list errlist = inputlist;
+                  errlist.prepend (msg);
+                  buffer_error_messages--;
+                  error_state = 0;
+                  tmp = error_handler.do_multi_index_op (nargout, errlist);
+                  buffer_error_messages++;
 
-		      if (error_state)
-			goto cellfun_err;
-		    }
+                  if (error_state)
+                    goto cellfun_err;
+                }
 
-		  if (tmp.length() < nargout)
-		    {
-		      error ("cellfun: too many output arguments");
-		      goto cellfun_err;
-		    }
+              if (tmp.length() < nargout)
+                {
+                  error ("cellfun: too many output arguments");
+                  goto cellfun_err;
+                }
 
-		  if (error_state)
-		    break;
+              if (error_state)
+                break;
 
-		  if (count == 0)
-		    {
-		      for (int j = 0; j < nargout; j++)
-			{
-			  octave_value val = tmp(j);
+              if (count == 0)
+                {
+                  for (int j = 0; j < nargout; j++)
+                    {
+                      octave_value val = tmp(j);
 
-                          if (val.numel () == 1)
-                            retptr[j].reset (make_col_helper (val, f_args.dims ()));
-                          else
-                            {
-                              error ("cellfun: expecting all values to be scalars for UniformOutput = true");
-                              break;
-                            }
-			}
-		    }
-		  else
-		    {
-		      for (int j = 0; j < nargout; j++)
-			{
-			  octave_value val = tmp(j);
+                      if (val.numel () == 1)
+                        retptr[j].reset (make_col_helper (val, f_args.dims ()));
+                      else
+                        {
+                          error ("cellfun: expecting all values to be scalars for UniformOutput = true");
+                          break;
+                        }
+                    }
+                }
+              else
+                {
+                  for (int j = 0; j < nargout; j++)
+                    {
+                      octave_value val = tmp(j);
 
-                          if (! retptr[j]->collect (count, val))
-                            {
-                              // FIXME: A more elaborate structure would allow again a virtual
-                              // constructor here.
-                              retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), 
-                                                                          f_args.dims ()));
-                              retptr[j]->collect (count, val);
-                            }
+                      if (! retptr[j]->collect (count, val))
+                        {
+                          // FIXME: A more elaborate structure would allow again a virtual
+                          // constructor here.
+                          retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), 
+                                                                      f_args.dims ()));
+                          retptr[j]->collect (count, val);
                         }
-		    }
+                    }
+                }
 
-		  if (error_state)
-		    break;
-		}
+              if (error_state)
+                break;
+            }
 
-              retval.resize (nargout);
-              for (int j = 0; j < nargout; j++)
-                {
-                  if (retptr[j].get ())
-                    retval(j) = retptr[j]->result ();
-                  else
-                    retval(j) = Matrix ();
-                }
-	    }
-	  else
-	    {
-	      OCTAVE_LOCAL_BUFFER (Cell, results, nargout);
-	      for (int j = 0; j < nargout; j++)
-		results[j].resize(f_args.dims());
+          retval.resize (nargout);
+          for (int j = 0; j < nargout; j++)
+            {
+              if (retptr[j].get ())
+                retval(j) = retptr[j]->result ();
+              else
+                retval(j) = Matrix ();
+            }
+        }
+      else
+        {
+          OCTAVE_LOCAL_BUFFER (Cell, results, nargout);
+          for (int j = 0; j < nargout; j++)
+            results[j].resize(f_args.dims());
 
-	      for (octave_idx_type count = 0; count < k ; count++)
-		{
-		  for (int j = 0; j < nargin; j++)
-		    inputlist(j) = cinputs[j](count);
+          for (octave_idx_type count = 0; count < k ; count++)
+            {
+              for (int j = 0; j < nargin; j++)
+                inputlist(j) = cinputs[j](count);
 
-		  octave_value_list tmp = feval (func, inputlist, nargout);
+              octave_value_list tmp = func.do_multi_index_op (nargout, inputlist);
 
-		  if (error_state && have_error_handler)
-		    {
-		      Octave_map msg;
-		      msg.assign ("identifier", last_error_id ());
-		      msg.assign ("message", last_error_message ());
-		      msg.assign ("index", octave_value(double (count + static_cast<octave_idx_type>(1))));
-		      octave_value_list errlist = inputlist;
-		      errlist.prepend (msg);
-		      buffer_error_messages--;
-		      error_state = 0;
-		      tmp = feval (error_handler, errlist, nargout);
-		      buffer_error_messages++;
+              if (error_state && error_handler.is_defined ())
+                {
+                  Octave_map msg;
+                  msg.assign ("identifier", last_error_id ());
+                  msg.assign ("message", last_error_message ());
+                  msg.assign ("index", octave_value(double (count + static_cast<octave_idx_type>(1))));
+                  octave_value_list errlist = inputlist;
+                  errlist.prepend (msg);
+                  buffer_error_messages--;
+                  error_state = 0;
+                  tmp = error_handler.do_multi_index_op (nargout, errlist);
+                  buffer_error_messages++;
 
-		      if (error_state)
-			goto cellfun_err;
-		    }
+                  if (error_state)
+                    goto cellfun_err;
+                }
 
-		  if (tmp.length() < nargout)
-		    {
-		      error ("cellfun: too many output arguments");
-		      goto cellfun_err;
-		    }
+              if (tmp.length() < nargout)
+                {
+                  error ("cellfun: too many output arguments");
+                  goto cellfun_err;
+                }
 
-		  if (error_state)
-		    break;
+              if (error_state)
+                break;
 
 
-		  for (int j = 0; j < nargout; j++)
-		    results[j](count) = tmp(j);
-		}
-
-	      retval.resize(nargout);
-	      for (int j = 0; j < nargout; j++)
-		retval(j) = results[j];
-	    }
+              for (int j = 0; j < nargout; j++)
+                results[j](count) = tmp(j);
+            }
 
-	cellfun_err:
-	  if (error_state)
-	    retval = octave_value_list();
+          retval.resize(nargout);
+          for (int j = 0; j < nargout; j++)
+            retval(j) = results[j];
+        }
 
-	  if (! fcn_name.empty ())
-	    clear_function (fcn_name);
-
-	  if (! err_name.empty ())
-	    clear_function (err_name);
-	}
+cellfun_err:
+      if (error_state)
+        retval = octave_value_list();
 
       unwind_protect::run_frame (uwp_frame);
     }
+  else
+    error ("cellfun: first argument must be a string or function handle");
 
   return retval;
 }