changeset 19008:80ca3b05d77c draft

New "dispatch" selects template argument from octave-value (Bug #42424, 42425) * find.cc (Ffind): This method now calls dispatch() rather than attempting to handle all matrix types on its own (findTemplated): Changed to a functor to be passed as a template template argument to dispatch() (findInfo): A struct that holds the other arguments to find (n_to_find, direction, nargout) Added unit tests for bugs 42424 and 42425 * (new file) dispatch.h (dispatch): A method for dispatching function calls to the right templated value based on an octave_value argument.
author David Spies <dnspies@gmail.com>
date Sat, 21 Jun 2014 13:13:05 -0600
parents 2e0613dadfee
children 8d47ce2053f2
files libinterp/corefcn/dispatch.h libinterp/corefcn/find.cc libinterp/corefcn/module.mk
diffstat 3 files changed, 245 insertions(+), 123 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/libinterp/corefcn/dispatch.h	Sat Jun 21 13:13:05 2014 -0600
@@ -0,0 +1,189 @@
+/*
+
+Copyright (C) 2014 David Spies
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+
+#if !defined (octave_dispatch_h)
+#define octave_dispatch_h 1
+
+#include <string>
+
+#include "gripes.h"
+#include "oct-obj.h"
+
+// This function takes a templated functor as a template-template argument
+// and calls it back with the matrix-type corresponding to arg.
+//
+// This is handy when you wish to write a templated function for dealing
+// with many different matrix types and you don't want to explicitly have
+// to list out all the different types your function can deal with and
+// how to call it with each one.
+//
+// It is expected that the functor operator() takes two arguments:
+//
+// 1. The value being unwrapped from arg
+// and
+// 2. A struct containing any other information the function needs
+// (ie other arguments, nargin, nargout etc.)
+//
+// The return value is an octave_value_list containing the return value
+// for the function
+//
+// In addition to the above two parameters, the dispatch function itself
+// also takes a string argument which is the name of the function being
+// called.  This argument is only used for error-reporting
+//
+// It is highly recommended you always use this function as you can be
+// sure dispatch will throw a compiler exception if you forget to handle
+// a particular type.
+//
+// For an example of how to call dispatch, see the "find" function in
+// libinterp/corefcn/find.cc
+template<template<typename > class fun, typename Inf>
+octave_value_list
+dispatch (const octave_value& arg, const Inf& info, const std::string& funname)
+{
+  octave_value_list retval;
+
+  if (arg.is_bool_type ())
+    {
+      if (arg.is_sparse_type ())
+        {
+          SparseBoolMatrix v = arg.sparse_bool_matrix_value ();
+
+          if (! error_state)
+            retval = fun<SparseBoolMatrix> () (v, info);
+        }
+      else
+        {
+          boolNDArray v = arg.bool_array_value ();
+
+          if (! error_state)
+            retval = fun<boolNDArray> () (v, info);
+        }
+    }
+  else if (arg.is_integer_type ())
+    {
+#define DO_INT_BRANCH(INTT) \
+      if (arg.is_ ## INTT ## _type ()) \
+        { \
+          INTT ## NDArray v = arg.INTT ## _array_value (); \
+          \
+          if (! error_state) \
+            retval = fun<INTT ## NDArray> () (v, info);\
+        } else
+
+      DO_INT_BRANCH (int8)
+      DO_INT_BRANCH (int16)
+      DO_INT_BRANCH (int32)
+      DO_INT_BRANCH (int64)
+      DO_INT_BRANCH (uint8)
+      DO_INT_BRANCH (uint16)
+      DO_INT_BRANCH (uint32)
+      DO_INT_BRANCH (uint64)
+        panic_impossible ();
+#undef DO_INT_BRANCH
+    }
+  else if (arg.is_sparse_type ())
+    {
+      if (arg.is_real_type ())
+        {
+          SparseMatrix v = arg.sparse_matrix_value ();
+
+          if (! error_state)
+            retval = fun<SparseMatrix> () (v, info);
+        }
+      else if (arg.is_complex_type ())
+        {
+          SparseComplexMatrix v = arg.sparse_complex_matrix_value ();
+
+          if (! error_state)
+            retval = fun<SparseComplexMatrix> () (v, info);
+        }
+      else
+        gripe_wrong_type_arg (funname, arg);
+    }
+  else if (arg.is_diag_matrix ())
+    {
+      if (arg.is_real_type ())
+        {
+          DiagMatrix v = arg.diag_matrix_value ();
+          if (! error_state)
+            retval = fun<DiagMatrix> () (v, info);
+        }
+      else if (arg.is_complex_type ())
+        {
+          ComplexDiagMatrix v = arg.complex_diag_matrix_value ();
+          if (! error_state)
+            retval = fun<ComplexDiagMatrix> () (v, info);
+        }
+    }
+  else if (arg.is_perm_matrix ())
+    {
+      PermMatrix v = arg.perm_matrix_value ();
+
+      if (! error_state)
+        retval = fun<PermMatrix> () (v, info);
+    }
+  else if (arg.is_string ())
+    {
+      charNDArray v = arg.char_array_value ();
+
+      if (! error_state)
+        retval = fun<charNDArray> () (v, info);
+    }
+  else if (arg.is_single_type ())
+    {
+      if (arg.is_real_type ())
+        {
+          FloatNDArray v = arg.float_array_value ();
+
+          if (! error_state)
+            retval = fun<FloatNDArray> () (v, info);
+        }
+      else if (arg.is_complex_type ())
+        {
+          FloatComplexNDArray v = arg.float_complex_array_value ();
+
+          if (! error_state)
+            retval = fun<FloatComplexNDArray> () (v, info);
+        }
+    }
+  else if (arg.is_real_type ())
+    {
+      NDArray v = arg.array_value ();
+
+      if (! error_state)
+        retval = fun<NDArray> () (v, info);
+    }
+  else if (arg.is_complex_type ())
+    {
+      ComplexNDArray v = arg.complex_array_value ();
+
+      if (! error_state)
+        retval = fun<ComplexNDArray> () (v, info);
+    }
+  else
+    gripe_wrong_type_arg (funname, arg);
+
+  return retval;
+}
+
+#endif
--- a/libinterp/corefcn/find.cc	Tue Jun 17 16:41:11 2014 -0600
+++ b/libinterp/corefcn/find.cc	Sat Jun 21 13:13:05 2014 -0600
@@ -25,6 +25,7 @@
 #include <config.h>
 #endif
 
+#include "dispatch.h"
 #include "find.h"
 
 #include "defun.h"
@@ -34,6 +35,10 @@
 
 namespace find
 {
+  // The find function should be seen as the canonical example demonstrating
+  // how to properly call dispatch.h
+  // It should always behave properly for all matrix types.
+  //
   // ffind_result is a generic type used for storing the result of
   // a find operation.  The way in which this result is stored will
   // vary based on whether the number of requested return values
@@ -181,13 +186,26 @@
     return octave_value_list ();
   }
 
+  struct find_info
+  {
+    octave_idx_type n_to_find;
+    direction dir;
+    int nargout;
+  };
+
+  // This functor will be called by dispatch.h with the proper type M.
+  // This avoids having to explicitly list the different types find can
+  // handle and instead delegates that duty to the generic "dispatch"
+  // function.
   template<typename M>
-  octave_value_list
-  find_templated (const M& v, int nargout, octave_idx_type n_to_find,
-                  direction dir)
+  struct find_templated
   {
-    return nargout_to_template (v, nargout, n_to_find, dir);
-  }
+    octave_value_list
+    operator() (const M& v, const find_info& inf)
+    {
+      return nargout_to_template (v, inf.nargout, inf.n_to_find, inf.dir);
+    }
+  };
 
 }
 
@@ -255,6 +273,15 @@
 @seealso{nonzeros}\n\
 @end deftypefn")
 {
+  find::find_info inf;
+
+  if(nargout < 1)
+    nargout = 1;
+  else if(nargout > 3)
+    nargout = 3;
+
+  inf.nargout = nargout;
+
   octave_value_list retval;
 
   int nargin = args.length ();
@@ -272,7 +299,7 @@
     nargout = 3;
 
   // Setup the default options.
-  octave_idx_type n_to_find = -1;
+  inf.n_to_find = -1;
   if (nargin > 1)
     {
       double val = args(1).scalar_value ();
@@ -283,11 +310,11 @@
           return retval;
         }
       else if (! xisinf (val))
-        n_to_find = val;
+        inf.n_to_find = val;
     }
 
   // Direction to do the searching.
-  direction dir = FORWARD;
+  inf.dir = FORWARD;
   if (nargin > 2)
     {
       std::string s_arg = args(2).string_value ();
@@ -298,9 +325,9 @@
           return retval;
         }
       if (s_arg == "first")
-        dir = FORWARD;
+        inf.dir = FORWARD;
       else if (s_arg == "last")
-        dir = BACKWARD;
+        inf.dir = BACKWARD;
       else
         {
           error ("find: DIRECTION must be \"first\" or \"last\"");
@@ -308,123 +335,22 @@
         }
     }
 
-  octave_value arg = args(0);
-
-  if (arg.is_bool_type ())
-    {
-      if (arg.is_sparse_type ())
-        {
-          SparseBoolMatrix v = arg.sparse_bool_matrix_value ();
-
-          if (! error_state)
-            retval = find::find_templated (v, nargout, n_to_find, dir);
-        }
-      else if (nargout <= 1 && n_to_find == -1)
-        {
-          // This case is equivalent to extracting indices from a logical
-          // matrix. Try to reuse the possibly cached index vector.
-          retval(0) = arg.index_vector ().unmask ();
-        }
-      else
-        {
-          boolNDArray v = arg.bool_array_value ();
-
-          if (! error_state)
-            retval = find::find_templated (v, nargout, n_to_find, dir);
-        }
-    }
-  else if (arg.is_integer_type ())
-    {
-#define DO_INT_BRANCH(INTT)                                               \
-      else if (arg.is_ ## INTT ## _type ())                               \
-        {                                                                 \
-          INTT ## NDArray v = arg.INTT ## _array_value ();                \
-                                                                          \
-            if (! error_state)                                            \
-              retval = find::find_templated (v, nargout, n_to_find, dir); \
-        }
-
-      if (false)
-        ;
-      DO_INT_BRANCH (int8)
-      DO_INT_BRANCH (int16)
-      DO_INT_BRANCH (int32)
-      DO_INT_BRANCH (int64)
-      DO_INT_BRANCH (uint8)
-      DO_INT_BRANCH (uint16)
-      DO_INT_BRANCH (uint32)
-      DO_INT_BRANCH (uint64)
-      else
-        panic_impossible ();
-    }
-  else if (arg.is_sparse_type ())
-    {
-      if (arg.is_real_type ())
-        {
-          SparseMatrix v = arg.sparse_matrix_value ();
+  const octave_value& arg = args(0);
 
-          if (! error_state)
-            retval = find::find_templated (v, nargout, n_to_find, dir);
-        }
-      else if (arg.is_complex_type ())
-        {
-          SparseComplexMatrix v = arg.sparse_complex_matrix_value ();
-
-          if (! error_state)
-            retval = find::find_templated (v, nargout, n_to_find, dir);
-        }
-      else
-        gripe_wrong_type_arg ("find", arg);
-    }
-  else if (arg.is_perm_matrix ())
-    {
-      PermMatrix P = arg.perm_matrix_value ();
-
-      if (! error_state)
-        retval = find::find_templated (P, nargout, n_to_find, dir);
-    }
-  else if (arg.is_string ())
-    {
-      charNDArray chnda = arg.char_array_value ();
-
-      if (! error_state)
-        retval = find::find_templated (chnda, nargout, n_to_find, dir);
-    }
-  else if (arg.is_single_type ())
+  //For this special case, it's unnecessary to call dispatch because
+  //we already know the types of everything
+  if (arg.is_bool_type() && inf.nargout <= 1 && inf.n_to_find == -1)
     {
-      if (arg.is_real_type ())
-        {
-          FloatNDArray nda = arg.float_array_value ();
-
-          if (! error_state)
-            retval = find::find_templated (nda, nargout, n_to_find, dir);
-        }
-      else if (arg.is_complex_type ())
-        {
-          FloatComplexNDArray cnda = arg.float_complex_array_value ();
-
-          if (! error_state)
-            retval = find::find_templated (cnda, nargout, n_to_find, dir);
-        }
+      // This case is equivalent to extracting indices from a logical
+      // matrix. Try to reuse the possibly cached index vector.
+      retval(0) = arg.index_vector ().unmask ();
+      return retval;
     }
-  else if (arg.is_real_type ())
-    {
-      NDArray nda = arg.array_value ();
 
-      if (! error_state)
-        retval = find::find_templated (nda, nargout, n_to_find, dir);
-    }
-  else if (arg.is_complex_type ())
-    {
-      ComplexNDArray cnda = arg.complex_array_value ();
-
-      if (! error_state)
-        retval = find::find_templated (cnda, nargout, n_to_find, dir);
-    }
-  else
-    gripe_wrong_type_arg ("find", arg);
-
-  return retval;
+  //Dispatches a call to the proper instantiation of the findTemplated
+  //functor.  This allows us to use the type of "arg" as a template
+  //argument to the find_to_iter function.
+  return dispatch<find::find_templated> (arg, inf, "find");
 }
 
 /*
@@ -481,5 +407,11 @@
 %! i = find(x);
 %! assert (i == 3e09);
 
+%!test
+%! fail("[a,b,c,d,e,f] = find(speye(3));")
+
+%!test
+%! [i,j] = find(eye(1000000));
+
 %!error find ()
 */
--- a/libinterp/corefcn/module.mk	Tue Jun 17 16:41:11 2014 -0600
+++ b/libinterp/corefcn/module.mk	Sat Jun 21 13:13:05 2014 -0600
@@ -53,6 +53,7 @@
   corefcn/defun-int.h \
   corefcn/defun.h \
   corefcn/dirfns.h \
+  corefcn/dispatch.h \
   corefcn/display.h \
   corefcn/dynamic-ld.h \
   corefcn/error.h \