changeset 9465:40de4692c860

auto-expanding scalar cells in cellfun
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 28 Jul 2009 13:46:23 +0200
parents e598248a060d
children 2ebd0717c12d
files src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc
diffstat 2 files changed, 95 insertions(+), 61 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Mon Jul 27 14:15:44 2009 +0200
+++ b/src/ChangeLog	Tue Jul 28 13:46:23 2009 +0200
@@ -1,3 +1,8 @@
+2009-07-28  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/cellfun.cc (Fcellfun): Support auto-expanding scalar
+	cells.
+
 2009-07-27  Jaroslav Hajek  <highegg@gmail.com>
 
 	* symtab.cc (symbol_table::fcn_info::fcn_info_rep::xfind,
--- a/src/DLD-FUNCTIONS/cellfun.cc	Mon Jul 27 14:15:44 2009 +0200
+++ b/src/DLD-FUNCTIONS/cellfun.cc	Tue Jul 28 13:46:23 2009 +0200
@@ -228,6 +228,8 @@
 \n\
 Note that the default output argument is an array of the same size as the\n\
 input arguments.\n\
+Input arguments that are singleton (1x1) cells will be automatically expanded\n\
+to the size of the other arguments.\n\
 \n\
 If the parameter 'UniformOutput' is set to true (the default), then the function\n\
 must return a single element which will be concatenated into the\n\
@@ -286,12 +288,12 @@
       return retval;
     }
   
-  const Cell f_args = args(1).cell_value ();
-  
-  octave_idx_type k = f_args.numel ();
-
   if (func.is_string ())
     {
+      const Cell f_args = args(1).cell_value ();
+
+      octave_idx_type k = f_args.numel ();
+
       std::string name = func.string_value ();
       if (name.find_first_of ("(x)") != std::string::npos)
         warning ("cellfun: passing function body as string is no longer supported."
@@ -396,79 +398,100 @@
       unwind_protect::frame_id_t uwp_frame = unwind_protect::begin_frame ();
       unwind_protect::protect_var (buffer_error_messages);
 
-      octave_value_list inputlist;
       bool uniform_output = true;
       octave_value error_handler;
-      int offset = 1;
-      int i = 1;
-      OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
-      // This is to prevent copy-on-write.
-      const Cell *cinputs = inputs;
 
-      while (i < nargin)
+      while (nargin > 3 && args(nargin-2).is_string())
         {
-          if (args(i).is_string())
-            {
-              std::string arg = args(i++).string_value();
-              if (i == nargin)
-                {
-                  error ("cellfun: parameter value is missing");
-                  goto cellfun_err;
-                }
+          std::string arg = args(nargin-2).string_value();
 
-              std::transform (arg.begin (), arg.end (), 
-                              arg.begin (), tolower);
+          std::transform (arg.begin (), arg.end (), 
+                          arg.begin (), tolower);
 
-              if (arg == "uniformoutput")
-                uniform_output = args(i++).bool_value();
-              else if (arg == "errorhandler")
+          if (arg == "uniformoutput")
+            uniform_output = args(nargin-1).bool_value();
+          else if (arg == "errorhandler")
+            {
+              if (args(nargin-1).is_function_handle () || 
+                  args(nargin-1).is_inline_function ())
                 {
-                  if (args(i).is_function_handle () || 
-                      args(i).is_inline_function ())
-                    {
-                      error_handler = args(i++);
-                    }
-                  else if (args(i).is_string ())
+                  error_handler = args(nargin-1);
+                }
+              else if (args(nargin-1).is_string ())
+                {
+                  std::string err_name = args(nargin-1).string_value ();
+                  error_handler = symbol_table::find_function (err_name);
+                  if (error_handler.is_undefined ())
                     {
-                      std::string err_name = args(i++).string_value ();
-                      error_handler = symbol_table::find_function (err_name);
-                      if (error_handler.is_undefined ())
-                        {
-                          error ("cellfun: invalid function name: %s", err_name.c_str ());
-                          goto cellfun_err;
-                        }
-                    }
-                  else
-                    {
-                      error ("invalid errorhandler value");
-                      goto cellfun_err;
+                      error ("cellfun: invalid function name: %s", err_name.c_str ());
+                      break;
                     }
                 }
               else
                 {
-                  error ("cellfun: unrecognized parameter %s", 
-                         arg.c_str());
-                  goto cellfun_err;
+                  error ("invalid errorhandler value");
+                  break;
                 }
-
-              offset += 2;
             }
           else
             {
-              inputs[i-offset] = args(i).cell_value ();
-              if (f_args.dims() != inputs[i-offset].dims())
+              error ("cellfun: unrecognized parameter %s", 
+                     arg.c_str());
+              break;
+            }
+
+          nargin -= 2;
+        }
+
+      nargin -= 1;
+
+      octave_value_list inputlist (nargin, octave_value ());
+
+      OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
+      OCTAVE_LOCAL_BUFFER (bool, mask, nargin);
+
+      // This is to prevent copy-on-write.
+      const Cell *cinputs = inputs;
+
+      octave_idx_type k;
+
+      dim_vector fdims (1, 1);
+
+      if (error_state)
+        goto cellfun_err;
+
+      for (int j = 0; j < nargin; j++)
+        {
+          if (! args(j+1).is_cell ())
+            {
+              error ("cellfun: arguments must be cells");
+              goto cellfun_err;
+            }
+
+          inputs[j] = args(j+1).cell_value ();
+          mask[j] = inputs[j].numel () != 1;
+          if (! mask[j])
+            inputlist(j) = cinputs[j](0);
+        }
+
+      k = inputs[0].numel ();
+
+      for (int j = 0; j < nargin; j++)
+        {
+          if (mask[j])
+            {
+              fdims = inputs[j].dims ();
+              for (int i = j+1; i < nargin; i++)
                 {
-                  error ("cellfun: Dimension mismatch");
-                  goto cellfun_err;
-
+                  if (mask[i] && inputs[i].dims () != fdims)
+                    {
+                      error ("cellfun: Dimensions mismatch.");
+                      goto cellfun_err;
+                    }
                 }
-              i++;
             }
         }
 
-      nargin -= offset;
-      inputlist.resize(nargin);
-
       if (error_handler.is_defined ())
         buffer_error_messages++;
 
@@ -479,7 +502,10 @@
           for (octave_idx_type count = 0; count < k ; count++)
             {
               for (int j = 0; j < nargin; j++)
-                inputlist(j) = cinputs[j](count);
+                {
+                  if (mask[j])
+                    inputlist(j) = cinputs[j](count);
+                }
 
               octave_value_list tmp = func.do_multi_index_op (nargout, inputlist);
 
@@ -516,7 +542,7 @@
                       octave_value val = tmp(j);
 
                       if (val.numel () == 1)
-                        retptr[j].reset (make_col_helper (val, f_args.dims ()));
+                        retptr[j].reset (make_col_helper (val, fdims));
                       else
                         {
                           error ("cellfun: expecting all values to be scalars for UniformOutput = true");
@@ -535,7 +561,7 @@
                           // FIXME: A more elaborate structure would allow again a virtual
                           // constructor here.
                           retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), 
-                                                                      f_args.dims ()));
+                                                                      fdims));
                           retptr[j]->collect (count, val);
                         }
                     }
@@ -558,12 +584,15 @@
         {
           OCTAVE_LOCAL_BUFFER (Cell, results, nargout);
           for (int j = 0; j < nargout; j++)
-            results[j].resize(f_args.dims());
+            results[j].resize (fdims);
 
           for (octave_idx_type count = 0; count < k ; count++)
             {
               for (int j = 0; j < nargin; j++)
-                inputlist(j) = cinputs[j](count);
+                {
+                  if (mask[j])
+                    inputlist(j) = cinputs[j](count);
+                }
 
               octave_value_list tmp = func.do_multi_index_op (nargout, inputlist);