diff src/DLD-FUNCTIONS/bsxfun.cc @ 9743:26abff55f6fe

optimize bsxfun for common built-in operations
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 20 Oct 2009 10:47:22 +0200
parents cf714e75c656
children 119d97db51f0
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/bsxfun.cc	Mon Oct 19 19:17:49 2009 -0700
+++ b/src/DLD-FUNCTIONS/bsxfun.cc	Tue Oct 20 10:47:22 2009 +0200
@@ -36,6 +36,154 @@
 #include "variables.h"
 #include "ov-colon.h"
 #include "unwind-prot.h"
+#include "ov-fcn-handle.h"
+
+// Operations bsxfun can hand to built-in handlers; values are table indices.
+enum bsxfun_builtin_op
+{
+  bsxfun_builtin_plus = 0,
+  bsxfun_builtin_minus,
+  bsxfun_builtin_times,
+  bsxfun_builtin_divide,  // listed as "rdivide" in bsxfun_builtin_names
+  bsxfun_builtin_max,
+  bsxfun_builtin_min,
+  bsxfun_builtin_eq,
+  bsxfun_builtin_ne,
+  bsxfun_builtin_lt,
+  bsxfun_builtin_le,
+  bsxfun_builtin_gt,
+  bsxfun_builtin_ge,
+  bsxfun_builtin_and,
+  bsxfun_builtin_or,
+  bsxfun_builtin_unknown,  // returned by the lookup when no name matches
+  bsxfun_num_builtin_ops = bsxfun_builtin_unknown  // count of real operations
+};
+
+const char *bsxfun_builtin_names[] =  // entry i names enumerator i above
+{
+  "plus",
+  "minus",
+  "times",
+  "rdivide",  // bsxfun_builtin_divide
+  "max",
+  "min",
+  "eq",
+  "ne",
+  "lt",
+  "le",
+  "gt",
+  "ge",
+  "and",
+  "or"  // last entry: bsxfun_builtin_unknown has no name
+};
+
+// Find NAME in bsxfun_builtin_names; a miss yields bsxfun_builtin_unknown.
+static bsxfun_builtin_op bsxfun_builtin_lookup (const std::string& name)
+{
+  int idx = 0;
+  while (idx < bsxfun_num_builtin_ops && name != bsxfun_builtin_names[idx])
+    idx++;
+  return static_cast<bsxfun_builtin_op> (idx);  // count == unknown on a miss
+}
+
+// Signature of an optimized handler: takes both operands, returns the result.
+typedef octave_value (*bsxfun_handler) (const octave_value&, const octave_value&);
+// Static table of handlers, indexed [operation][builtin type]; null if unset.
+bsxfun_handler bsxfun_handler_table[bsxfun_num_builtin_ops][btyp_num_types];
+
+// Extract X and Y as NDA, apply bsxfun_op, and wrap the NDA result.
+template <class NDA, NDA (bsxfun_op) (const NDA&, const NDA&)>
+static octave_value
+bsxfun_forward_op (const octave_value& x, const octave_value& y)
+{
+  const NDA lhs = octave_value_extract<NDA> (x);
+  return octave_value (bsxfun_op (lhs, octave_value_extract<NDA> (y)));
+}
+
+// Extract X and Y as NDA, apply bsxfun_rel, and wrap the boolNDArray result.
+template <class NDA, boolNDArray (bsxfun_rel) (const NDA&, const NDA&)>
+static octave_value
+bsxfun_forward_rel (const octave_value& x, const octave_value& y)
+{
+  const NDA lhs = octave_value_extract<NDA> (x);
+  return octave_value (bsxfun_rel (lhs, octave_value_extract<NDA> (y)));
+}
+
+// Fill bsxfun_handler_table on the first call; later calls return at once.
+static void maybe_fill_table (void)
+{
+  static bool filled = false;
+  if (filled)
+    return;
+#define REGISTER_OP_HANDLER(OP, BTYP, NDA, FUNOP) \
+  bsxfun_handler_table[OP][BTYP] = bsxfun_forward_op<NDA, FUNOP>
+#define REGISTER_REL_HANDLER(REL, BTYP, NDA, FUNREL) \
+  bsxfun_handler_table[REL][BTYP] = bsxfun_forward_rel<NDA, FUNREL>
+#define REGISTER_STD_HANDLERS(BTYP, NDA) \
+  REGISTER_OP_HANDLER (bsxfun_builtin_plus, BTYP, NDA, bsxfun_add); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_minus, BTYP, NDA, bsxfun_sub); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_times, BTYP, NDA, bsxfun_mul); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_divide, BTYP, NDA, bsxfun_div); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_max, BTYP, NDA, bsxfun_max); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_min, BTYP, NDA, bsxfun_min); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_eq, BTYP, NDA, bsxfun_eq); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_ne, BTYP, NDA, bsxfun_ne); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_lt, BTYP, NDA, bsxfun_lt); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_le, BTYP, NDA, bsxfun_le); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_gt, BTYP, NDA, bsxfun_gt); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_ge, BTYP, NDA, bsxfun_ge)
+
+  REGISTER_STD_HANDLERS (btyp_double, NDArray);
+  REGISTER_STD_HANDLERS (btyp_float, FloatNDArray);
+  REGISTER_STD_HANDLERS (btyp_complex, ComplexNDArray);
+  REGISTER_STD_HANDLERS (btyp_float_complex, FloatComplexNDArray);
+  REGISTER_STD_HANDLERS (btyp_int8,  int8NDArray);
+  REGISTER_STD_HANDLERS (btyp_int16, int16NDArray);
+  REGISTER_STD_HANDLERS (btyp_int32, int32NDArray);
+  REGISTER_STD_HANDLERS (btyp_int64, int64NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint8,  uint8NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint16, uint16NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint32, uint32NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint64, uint64NDArray);
+  // For bools, we register and/or.
+  REGISTER_OP_HANDLER (bsxfun_builtin_and, btyp_bool, boolNDArray, bsxfun_and);
+  REGISTER_OP_HANDLER (bsxfun_builtin_or, btyp_bool, boolNDArray, bsxfun_or);
+  filled = true;  // remember that the table is now populated
+}
+
+// Try the fast path: if NAME is a known operation and A and B have (or can be
+// reduced to) the same builtin type, run the registered handler.  Returns a
+// default-constructed octave_value when no handler applies.
+static octave_value
+maybe_optimized_builtin (const std::string& name,
+                         const octave_value& a, const octave_value& b)
+{
+  maybe_fill_table ();
+
+  const bsxfun_builtin_op op = bsxfun_builtin_lookup (name);
+  if (op == bsxfun_builtin_unknown)
+    return octave_value ();
+
+  builtin_type_t ta = a.builtin_type ();
+  builtin_type_t tb = b.builtin_type ();
+
+  // Mixed single/double pairs (real or complex) are both treated as single.
+  if (ta == btyp_float && tb == btyp_double)
+    tb = btyp_float;
+  else if (ta == btyp_double && tb == btyp_float)
+    ta = btyp_float;
+  else if (ta == btyp_float_complex && tb == btyp_complex)
+    tb = btyp_float_complex;
+  else if (ta == btyp_complex && tb == btyp_float_complex)
+    ta = btyp_float_complex;
+
+  if (ta != tb || ta == btyp_unknown)
+    return octave_value ();
+
+  // A null table slot means nothing was registered for this op/type pair.
+  const bsxfun_handler handler = bsxfun_handler_table[op][ta];
+  return handler ? handler (a, b) : octave_value ();
+}
 
 static bool
 maybe_update_column (octave_value& Ac, const octave_value& A, 
@@ -160,12 +308,27 @@
       else if (! (args(0).is_function_handle () || args(0).is_inline_function ()))
         error ("bsxfun: first argument must be a string or function handle");
 
-      if (! error_state)
+      const octave_value A = args (1);
+      const octave_value B = args (2);
+
+      if (func.is_builtin_function () 
+          || (func.is_function_handle () 
+              && ! func.fcn_handle_value ()->is_overloaded () 
+              && ! A.is_object () && ! B.is_object ()))
+        {
+          octave_function *fcn_val = func.function_value ();
+          if (fcn_val)
+            {
+              octave_value tmp = maybe_optimized_builtin (fcn_val->name (), A, B);
+              if (tmp.is_defined ())
+                retval(0) = tmp;
+            }
+        }
+
+      if (! error_state && retval.empty ())
 	{
-	  const octave_value A = args (1);
 	  dim_vector dva = A.dims ();
 	  octave_idx_type nda = dva.length ();
-	  const octave_value B = args (2);
 	  dim_vector dvb = B.dims ();
 	  octave_idx_type ndb = dvb.length ();
 	  octave_idx_type nd = nda;