changeset 17892:5401637c3fa7

Make cellfun obey "UniformOutput" for optimised internal functions (bug #40467) * cellfun.cc (try_cellfun_internal_ops): Templatise the return values of this function. Replace all return types by template params. (Fcellfun): Move the argument check above the attempt to optimise the call for certain functions. Call try_cellfun_internal_ops with appropriate template parameters depending if uniform_output is requested or not. Add tests.
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Sat, 09 Nov 2013 21:47:12 -0500
parents 5fbab07c419f
children 38b726ed04c9
files libinterp/corefcn/cellfun.cc
diffstat 1 files changed, 34 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/cellfun.cc	Wed Nov 06 22:23:21 2013 +0100
+++ b/libinterp/corefcn/cellfun.cc	Sat Nov 09 21:47:12 2013 -0500
@@ -112,6 +112,10 @@
   return tmp;
 }
 
+// Templated function because the user can be stubborn enough to request
+// a cell array as an output even in these cases where the output fits
+// in an ordinary array
+template<typename BNDA, typename NDA>
 static octave_value_list
 try_cellfun_internal_ops (const octave_value_list& args, int nargin)
 {
@@ -125,49 +129,49 @@
 
   if (name == "isempty")
     {
-      boolNDArray result (f_args.dims ());
+      BNDA result (f_args.dims ());
       for (octave_idx_type count = 0; count < k; count++)
         result(count) = f_args.elem (count).is_empty ();
       retval(0) = result;
     }
   else if (name == "islogical")
     {
-      boolNDArray result (f_args.dims ());
+      BNDA result (f_args.dims ());
       for (octave_idx_type  count= 0; count < k; count++)
         result(count) = f_args.elem (count).is_bool_type ();
       retval(0) = result;
     }
   else if (name == "isnumeric")
     {
-      boolNDArray result (f_args.dims ());
+      BNDA result (f_args.dims ());
       for (octave_idx_type  count= 0; count < k; count++)
         result(count) = f_args.elem (count).is_numeric_type ();
       retval(0) = result;
     }
   else if (name == "isreal")
     {
-      boolNDArray result (f_args.dims ());
+      BNDA result (f_args.dims ());
       for (octave_idx_type  count= 0; count < k; count++)
         result(count) = f_args.elem (count).is_real_type ();
       retval(0) = result;
     }
   else if (name == "length")
     {
-      NDArray result (f_args.dims ());
+      NDA result (f_args.dims ());
       for (octave_idx_type  count= 0; count < k; count++)
         result(count) = static_cast<double> (f_args.elem (count).length ());
       retval(0) = result;
     }
   else if (name == "ndims")
     {
-      NDArray result (f_args.dims ());
+      NDA result (f_args.dims ());
       for (octave_idx_type count = 0; count < k; count++)
         result(count) = static_cast<double> (f_args.elem (count).ndims ());
       retval(0) = result;
     }
   else if (name == "numel" || name == "prodofsize")
     {
-      NDArray result (f_args.dims ());
+      NDA result (f_args.dims ());
       for (octave_idx_type count = 0; count < k; count++)
         result(count) = static_cast<double> (f_args.elem (count).numel ());
       retval(0) = result;
@@ -183,7 +187,7 @@
 
           if (! error_state)
             {
-              NDArray result (f_args.dims ());
+              NDA result (f_args.dims ());
               for (octave_idx_type count = 0; count < k; count++)
                 {
                   dim_vector dv = f_args.elem (count).dims ();
@@ -203,7 +207,7 @@
       if (nargin == 3)
         {
           std::string class_name = args(2).string_value ();
-          boolNDArray result (f_args.dims ());
+          BNDA result (f_args.dims ());
           for (octave_idx_type count = 0; count < k; count++)
             result(count) = (f_args.elem (count).class_name () == class_name);
 
@@ -427,7 +431,7 @@
 
   if (func.is_string ())
     {
-      retval = try_cellfun_internal_ops (args, nargin);
+      retval = try_cellfun_internal_ops<boolNDArray,NDArray>(args, nargin);
 
       if (error_state || ! retval.empty ())
         return retval;
@@ -464,6 +468,11 @@
       || func.is_function ())
     {
 
+      bool uniform_output = true;
+      octave_value error_handler;
+
+      get_mapper_fun_options (args, nargin, uniform_output, error_handler);
+
       // The following is an optimisation because the symbol table can
       // give a more specific function class, so this can result in
       // fewer polymorphic function calls as the function gets called
@@ -491,7 +500,15 @@
                 //Try first the optimised code path for built-in functions
                 octave_value_list tmp_args = args;
                 tmp_args(0) = name;
-                retval = try_cellfun_internal_ops (tmp_args, nargin);
+
+                if (uniform_output)
+                  retval =
+                    try_cellfun_internal_ops<boolNDArray, NDArray> (tmp_args,
+                                                                    nargin);
+                else
+                  retval =
+                    try_cellfun_internal_ops<Cell, Cell> (tmp_args, nargin);
+
                 if (error_state || ! retval.empty ())
                   return retval;
               }
@@ -504,11 +521,6 @@
       }
     nevermind:
 
-      bool uniform_output = true;
-      octave_value error_handler;
-
-      get_mapper_fun_options (args, nargin, uniform_output, error_handler);
-
       if (error_state)
         return octave_value_list ();
 
@@ -1029,6 +1041,12 @@
 %! assert (b, {"c", "g"});
 %! assert (c, {".d", ".h"});
 
+## Tests for bug #40467
+%!assert (cellfun (@isreal, {1 inf nan []}), [true, true, true, true]);
+%!assert (cellfun (@isreal, {1 inf nan []}, "UniformOutput", false), {true, true, true, true});
+%!assert (cellfun (@iscomplex, {1 inf nan []}), [false, false, false, false]);
+%!assert (cellfun (@iscomplex, {1 inf nan []}, "UniformOutput", false), {false, false, false, false});
+
 %!error cellfun (1)
 %!error cellfun ("isclass", 1)
 %!error cellfun ("size", 1)