Mercurial > octave
diff src/DLD-FUNCTIONS/bsxfun.cc @ 9743:26abff55f6fe
optimize bsxfun for common built-in operations
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 20 Oct 2009 10:47:22 +0200 |
parents | cf714e75c656 |
children | 119d97db51f0 |
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/bsxfun.cc	Mon Oct 19 19:17:49 2009 -0700
+++ b/src/DLD-FUNCTIONS/bsxfun.cc	Tue Oct 20 10:47:22 2009 +0200
@@ -36,6 +36,169 @@
 #include "variables.h"
 #include "ov-colon.h"
 #include "unwind-prot.h"
+#include "ov-fcn-handle.h"
+
+// Optimized bsxfun operations
+enum bsxfun_builtin_op
+{
+  bsxfun_builtin_plus = 0,
+  bsxfun_builtin_minus,
+  bsxfun_builtin_times,
+  bsxfun_builtin_divide,
+  bsxfun_builtin_max,
+  bsxfun_builtin_min,
+  bsxfun_builtin_eq,
+  bsxfun_builtin_ne,
+  bsxfun_builtin_lt,
+  bsxfun_builtin_le,
+  bsxfun_builtin_gt,
+  bsxfun_builtin_ge,
+  bsxfun_builtin_and,
+  bsxfun_builtin_or,
+  bsxfun_builtin_unknown,
+  bsxfun_num_builtin_ops = bsxfun_builtin_unknown
+};
+
+// Names of the optimized operations, indexed by bsxfun_builtin_op.
+// Must list exactly bsxfun_num_builtin_ops entries, in enum order.
+const char *bsxfun_builtin_names[] =
+{
+  "plus",
+  "minus",
+  "times",
+  "rdivide",
+  "max",
+  "min",
+  "eq",
+  "ne",
+  "lt",
+  "le",
+  "gt",
+  "ge",
+  "and",
+  "or"
+};
+
+// Map a function name to its bsxfun_builtin_op, or return
+// bsxfun_builtin_unknown if the name is not in the table.
+static bsxfun_builtin_op
+bsxfun_builtin_lookup (const std::string& name)
+{
+  for (int i = 0; i < bsxfun_num_builtin_ops; i++)
+    if (name == bsxfun_builtin_names[i])
+      return static_cast<bsxfun_builtin_op> (i);
+  return bsxfun_builtin_unknown;
+}
+
+// Signature of an optimized binary handler.
+typedef octave_value (*bsxfun_handler) (const octave_value&, const octave_value&);
+
+// Static table of handlers, indexed by operation and builtin type.
+// Entries that are never registered stay null (static storage).
+bsxfun_handler bsxfun_handler_table[bsxfun_num_builtin_ops][btyp_num_types];
+
+// Extract both operands as NDA and apply an NDA-valued bsxfun operation.
+template <class NDA, NDA (bsxfun_op) (const NDA&, const NDA&)>
+static octave_value
+bsxfun_forward_op (const octave_value& x, const octave_value& y)
+{
+  NDA xa = octave_value_extract<NDA> (x);
+  NDA ya = octave_value_extract<NDA> (y);
+  return octave_value (bsxfun_op (xa, ya));
+}
+
+// Likewise, for relational operations yielding a boolNDArray.
+template <class NDA, boolNDArray (bsxfun_rel) (const NDA&, const NDA&)>
+static octave_value
+bsxfun_forward_rel (const octave_value& x, const octave_value& y)
+{
+  NDA xa = octave_value_extract<NDA> (x);
+  NDA ya = octave_value_extract<NDA> (y);
+  return octave_value (bsxfun_rel (xa, ya));
+}
+
+// Populate bsxfun_handler_table.  The work is done on the first call
+// only; later calls return immediately.
+static void maybe_fill_table (void)
+{
+  static bool filled = false;
+  if (filled)
+    return;
+
+#define REGISTER_OP_HANDLER(OP, BTYP, NDA, FUNOP) \
+  bsxfun_handler_table[OP][BTYP] = bsxfun_forward_op<NDA, FUNOP>
+#define REGISTER_REL_HANDLER(REL, BTYP, NDA, FUNREL) \
+  bsxfun_handler_table[REL][BTYP] = bsxfun_forward_rel<NDA, FUNREL>
+#define REGISTER_STD_HANDLERS(BTYP, NDA) \
+  REGISTER_OP_HANDLER (bsxfun_builtin_plus, BTYP, NDA, bsxfun_add); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_minus, BTYP, NDA, bsxfun_sub); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_times, BTYP, NDA, bsxfun_mul); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_divide, BTYP, NDA, bsxfun_div); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_max, BTYP, NDA, bsxfun_max); \
+  REGISTER_OP_HANDLER (bsxfun_builtin_min, BTYP, NDA, bsxfun_min); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_eq, BTYP, NDA, bsxfun_eq); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_ne, BTYP, NDA, bsxfun_ne); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_lt, BTYP, NDA, bsxfun_lt); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_le, BTYP, NDA, bsxfun_le); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_gt, BTYP, NDA, bsxfun_gt); \
+  REGISTER_REL_HANDLER (bsxfun_builtin_ge, BTYP, NDA, bsxfun_ge)
+
+  REGISTER_STD_HANDLERS (btyp_double, NDArray);
+  REGISTER_STD_HANDLERS (btyp_float, FloatNDArray);
+  REGISTER_STD_HANDLERS (btyp_complex, ComplexNDArray);
+  REGISTER_STD_HANDLERS (btyp_float_complex, FloatComplexNDArray);
+  REGISTER_STD_HANDLERS (btyp_int8, int8NDArray);
+  REGISTER_STD_HANDLERS (btyp_int16, int16NDArray);
+  REGISTER_STD_HANDLERS (btyp_int32, int32NDArray);
+  REGISTER_STD_HANDLERS (btyp_int64, int64NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint8, uint8NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint16, uint16NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint32, uint32NDArray);
+  REGISTER_STD_HANDLERS (btyp_uint64, uint64NDArray);
+
+  // For bools, we register and/or.
+  REGISTER_OP_HANDLER (bsxfun_builtin_and, btyp_bool, boolNDArray, bsxfun_and);
+  REGISTER_OP_HANDLER (bsxfun_builtin_or, btyp_bool, boolNDArray, bsxfun_or);
+
+  // Mark the table as built so that later calls return early.
+  filled = true;
+}
+
+// Try to evaluate the builtin NAME on A and B through the handler table.
+// Returns an undefined octave_value if no optimized handler applies.
+static octave_value
+maybe_optimized_builtin (const std::string& name,
+                         const octave_value& a, const octave_value& b)
+{
+  octave_value retval;
+
+  maybe_fill_table ();
+
+  bsxfun_builtin_op op = bsxfun_builtin_lookup (name);
+  if (op != bsxfun_builtin_unknown)
+    {
+      builtin_type_t btyp_a = a.builtin_type (), btyp_b = b.builtin_type ();
+
+      // Simplify single/double combinations.
+      if (btyp_a == btyp_float && btyp_b == btyp_double)
+        btyp_b = btyp_float;
+      else if (btyp_a == btyp_double && btyp_b == btyp_float)
+        btyp_a = btyp_float;
+      else if (btyp_a == btyp_float_complex && btyp_b == btyp_complex)
+        btyp_b = btyp_float_complex;
+      else if (btyp_a == btyp_complex && btyp_b == btyp_float_complex)
+        btyp_a = btyp_float_complex;
+
+      if (btyp_a == btyp_b && btyp_a != btyp_unknown)
+        {
+          bsxfun_handler handler = bsxfun_handler_table[op][btyp_a];
+          if (handler)
+            retval = handler (a, b);
+        }
+    }
+
+  return retval;
+}
 
 static bool
 maybe_update_column (octave_value& Ac, const octave_value& A,
@@ -160,12 +323,27 @@
       else if (! (args(0).is_function_handle () || args(0).is_inline_function ()))
         error ("bsxfun: first argument must be a string or function handle");
 
-      if (! error_state)
+      const octave_value A = args (1);
+      const octave_value B = args (2);
+
+      if (func.is_builtin_function ()
+          || (func.is_function_handle ()
+              && ! func.fcn_handle_value ()->is_overloaded ()
+              && ! A.is_object () && ! B.is_object ()))
+        {
+          octave_function *fcn_val = func.function_value ();
+          if (fcn_val)
+            {
+              octave_value tmp = maybe_optimized_builtin (fcn_val->name (), A, B);
+              if (tmp.is_defined ())
+                retval(0) = tmp;
+            }
+        }
+
+      if (! error_state && retval.empty ())
         {
-          const octave_value A = args (1);
           dim_vector dva = A.dims ();
           octave_idx_type nda = dva.length ();
-          const octave_value B = args (2);
           dim_vector dvb = B.dims ();
           octave_idx_type ndb = dvb.length ();
           octave_idx_type nd = nda;