Mercurial > octave
diff src/corefcn/bsxfun.cc @ 15039:e753177cde93
maint: Move non-dynamically linked functions from DLD-FUNCTIONS/ to corefcn/ directory
* __contourc__.cc, __dispatch__.cc, __lin_interpn__.cc, __pchip_deriv__.cc,
__qp__.cc, balance.cc, besselj.cc, betainc.cc, bsxfun.cc, cellfun.cc,
colloc.cc, conv2.cc, daspk.cc, dasrt.cc, dassl.cc, det.cc, dlmread.cc, dot.cc,
eig.cc, fft.cc, fft2.cc, fftn.cc, filter.cc, find.cc, gammainc.cc, gcd.cc,
getgrent.cc, getpwent.cc, getrusage.cc, givens.cc, hess.cc, hex2num.cc, inv.cc,
kron.cc, lookup.cc, lsode.cc, lu.cc, luinc.cc, matrix_type.cc, max.cc,
md5sum.cc, mgorth.cc, nproc.cc, pinv.cc, quad.cc, quadcc.cc, qz.cc,
rand.cc, rcond.cc, regexp.cc, schur.cc, spparms.cc, sqrtm.cc, str2double.cc,
strfind.cc, sub2ind.cc, svd.cc, syl.cc, time.cc, tril.cc, typecast.cc:
Move functions from DLD-FUNCTIONS/ to corefcn/ directory. Include "defun.h",
not "defun-dld.h". Change docstring to refer to these as "Built-in Functions".
* build-aux/mk-opts.pl: Generate options code with '#include "defun.h"'. Change
option docstrings to refer to these as "Built-in Functions".
* corefcn/module.mk: List of functions to build in corefcn/ dir.
* DLD-FUNCTIONS/config-module.awk: Update to new build system.
* DLD-FUNCTIONS/module-files: Remove functions which are now in corefcn/ directory.
* src/Makefile.am: Update to build "convenience library" in corefcn/. Octave
program now links against all other libraries + corefcn libary.
* src/find-defun-files.sh: Strip $srcdir from filename.
* src/link-deps.mk: Add REGEX and FFTW link dependencies for liboctinterp.
* type.m, which.m: Change failing tests to use 'amd', still a dynamic function,
rather than 'dot', which isn't.
author | Rik <rik@octave.org> |
---|---|
date | Fri, 27 Jul 2012 15:35:00 -0700 |
parents | src/DLD-FUNCTIONS/bsxfun.cc@5ae9f0f77635 |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/corefcn/bsxfun.cc Fri Jul 27 15:35:00 2012 -0700 @@ -0,0 +1,813 @@ +/* + +Copyright (C) 2007-2012 David Bateman +Copyright (C) 2009 VZLU Prague + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include <string> +#include <vector> +#include <list> + +#include "lo-mappers.h" + +#include "oct-map.h" +#include "defun.h" +#include "parse.h" +#include "variables.h" +#include "ov-colon.h" +#include "unwind-prot.h" +#include "ov-fcn-handle.h" + +// Optimized bsxfun operations +enum bsxfun_builtin_op +{ + bsxfun_builtin_plus = 0, + bsxfun_builtin_minus, + bsxfun_builtin_times, + bsxfun_builtin_divide, + bsxfun_builtin_max, + bsxfun_builtin_min, + bsxfun_builtin_eq, + bsxfun_builtin_ne, + bsxfun_builtin_lt, + bsxfun_builtin_le, + bsxfun_builtin_gt, + bsxfun_builtin_ge, + bsxfun_builtin_and, + bsxfun_builtin_or, + bsxfun_builtin_power, + bsxfun_builtin_unknown, + bsxfun_num_builtin_ops = bsxfun_builtin_unknown +}; + +const char *bsxfun_builtin_names[] = +{ + "plus", + "minus", + "times", + "rdivide", + "max", + "min", + "eq", + "ne", + "lt", + "le", + "gt", + "ge", + "and", + "or", + "power" +}; + +static bsxfun_builtin_op +bsxfun_builtin_lookup (const std::string& name) +{ + for (int i = 0; i < bsxfun_num_builtin_ops; i++) + if (name == bsxfun_builtin_names[i]) + return static_cast<bsxfun_builtin_op> (i); + return bsxfun_builtin_unknown; +} + +typedef octave_value (*bsxfun_handler) (const octave_value&, const octave_value&); + +// Static table of handlers. +bsxfun_handler bsxfun_handler_table[bsxfun_num_builtin_ops][btyp_num_types]; + +template <class NDA, NDA (bsxfun_op) (const NDA&, const NDA&)> +static octave_value +bsxfun_forward_op (const octave_value& x, const octave_value& y) +{ + NDA xa = octave_value_extract<NDA> (x); + NDA ya = octave_value_extract<NDA> (y); + return octave_value (bsxfun_op (xa, ya)); +} + +template <class NDA, boolNDArray (bsxfun_rel) (const NDA&, const NDA&)> +static octave_value +bsxfun_forward_rel (const octave_value& x, const octave_value& y) +{ + NDA xa = octave_value_extract<NDA> (x); + NDA ya = octave_value_extract<NDA> (y); + return octave_value (bsxfun_rel (xa, ya)); +} + +// Pow needs a special handler for reals because of the potentially complex result. +template <class NDA, class CNDA> +static octave_value +do_bsxfun_real_pow (const octave_value& x, const octave_value& y) +{ + NDA xa = octave_value_extract<NDA> (x); + NDA ya = octave_value_extract<NDA> (y); + if (! ya.all_integers () && xa.any_element_is_negative ()) + return octave_value (bsxfun_pow (CNDA (xa), ya)); + else + return octave_value (bsxfun_pow (xa, ya)); +} + +static void maybe_fill_table (void) +{ + static bool filled = false; + if (filled) + return; + +#define REGISTER_OP_HANDLER(OP, BTYP, NDA, FUNOP) \ + bsxfun_handler_table[OP][BTYP] = bsxfun_forward_op<NDA, FUNOP> +#define REGISTER_REL_HANDLER(REL, BTYP, NDA, FUNREL) \ + bsxfun_handler_table[REL][BTYP] = bsxfun_forward_rel<NDA, FUNREL> +#define REGISTER_STD_HANDLERS(BTYP, NDA) \ + REGISTER_OP_HANDLER (bsxfun_builtin_plus, BTYP, NDA, bsxfun_add); \ + REGISTER_OP_HANDLER (bsxfun_builtin_minus, BTYP, NDA, bsxfun_sub); \ + REGISTER_OP_HANDLER (bsxfun_builtin_times, BTYP, NDA, bsxfun_mul); \ + REGISTER_OP_HANDLER (bsxfun_builtin_divide, BTYP, NDA, bsxfun_div); \ + REGISTER_OP_HANDLER (bsxfun_builtin_max, BTYP, NDA, bsxfun_max); \ + REGISTER_OP_HANDLER (bsxfun_builtin_min, BTYP, NDA, bsxfun_min); \ + REGISTER_REL_HANDLER (bsxfun_builtin_eq, BTYP, NDA, bsxfun_eq); \ + REGISTER_REL_HANDLER (bsxfun_builtin_ne, BTYP, NDA, bsxfun_ne); \ + REGISTER_REL_HANDLER (bsxfun_builtin_lt, BTYP, NDA, bsxfun_lt); \ + REGISTER_REL_HANDLER (bsxfun_builtin_le, BTYP, NDA, bsxfun_le); \ + REGISTER_REL_HANDLER (bsxfun_builtin_gt, BTYP, NDA, bsxfun_gt); \ + REGISTER_REL_HANDLER (bsxfun_builtin_ge, BTYP, NDA, bsxfun_ge) + + REGISTER_STD_HANDLERS (btyp_double, NDArray); + REGISTER_STD_HANDLERS (btyp_float, FloatNDArray); + REGISTER_STD_HANDLERS (btyp_complex, ComplexNDArray); + REGISTER_STD_HANDLERS (btyp_float_complex, FloatComplexNDArray); + REGISTER_STD_HANDLERS (btyp_int8, int8NDArray); + REGISTER_STD_HANDLERS (btyp_int16, int16NDArray); + REGISTER_STD_HANDLERS (btyp_int32, int32NDArray); + REGISTER_STD_HANDLERS (btyp_int64, int64NDArray); + REGISTER_STD_HANDLERS (btyp_uint8, uint8NDArray); + REGISTER_STD_HANDLERS (btyp_uint16, uint16NDArray); + REGISTER_STD_HANDLERS (btyp_uint32, uint32NDArray); + REGISTER_STD_HANDLERS (btyp_uint64, uint64NDArray); + + // For bools, we register and/or. + REGISTER_OP_HANDLER (bsxfun_builtin_and, btyp_bool, boolNDArray, bsxfun_and); + REGISTER_OP_HANDLER (bsxfun_builtin_or, btyp_bool, boolNDArray, bsxfun_or); + + // Register power handlers. + bsxfun_handler_table[bsxfun_builtin_power][btyp_double] = + do_bsxfun_real_pow<NDArray, ComplexNDArray>; + bsxfun_handler_table[bsxfun_builtin_power][btyp_float] = + do_bsxfun_real_pow<FloatNDArray, FloatComplexNDArray>; + + REGISTER_OP_HANDLER (bsxfun_builtin_power, btyp_complex, ComplexNDArray, bsxfun_pow); + REGISTER_OP_HANDLER (bsxfun_builtin_power, btyp_float_complex, FloatComplexNDArray, bsxfun_pow); + + // For chars, we want just relational handlers. + REGISTER_REL_HANDLER (bsxfun_builtin_eq, btyp_char, charNDArray, bsxfun_eq); + REGISTER_REL_HANDLER (bsxfun_builtin_ne, btyp_char, charNDArray, bsxfun_ne); + REGISTER_REL_HANDLER (bsxfun_builtin_lt, btyp_char, charNDArray, bsxfun_lt); + REGISTER_REL_HANDLER (bsxfun_builtin_le, btyp_char, charNDArray, bsxfun_le); + REGISTER_REL_HANDLER (bsxfun_builtin_gt, btyp_char, charNDArray, bsxfun_gt); + REGISTER_REL_HANDLER (bsxfun_builtin_ge, btyp_char, charNDArray, bsxfun_ge); + + filled = true; +} + +static octave_value +maybe_optimized_builtin (const std::string& name, + const octave_value& a, const octave_value& b) +{ + octave_value retval; + + maybe_fill_table (); + + bsxfun_builtin_op op = bsxfun_builtin_lookup (name); + if (op != bsxfun_builtin_unknown) + { + builtin_type_t btyp_a = a.builtin_type (), btyp_b = b.builtin_type (); + + // Simplify single/double combinations. + if (btyp_a == btyp_float && btyp_b == btyp_double) + btyp_b = btyp_float; + else if (btyp_a == btyp_double && btyp_b == btyp_float) + btyp_a = btyp_float; + else if (btyp_a == btyp_float_complex && btyp_b == btyp_complex) + btyp_b = btyp_float_complex; + else if (btyp_a == btyp_complex && btyp_b == btyp_float_complex) + btyp_a = btyp_float_complex; + + if (btyp_a == btyp_b && btyp_a != btyp_unknown) + { + bsxfun_handler handler = bsxfun_handler_table[op][btyp_a]; + if (handler) + retval = handler (a, b); + } + } + + return retval; +} + +static bool +maybe_update_column (octave_value& Ac, const octave_value& A, + const dim_vector& dva, const dim_vector& dvc, + octave_idx_type i, octave_value_list &idx) +{ + octave_idx_type nd = dva.length (); + + if (i == 0) + { + idx(0) = octave_value (':'); + for (octave_idx_type j = 1; j < nd; j++) + { + if (dva (j) == 1) + idx(j) = octave_value (1); + else + idx(j) = octave_value ((i % dvc(j)) + 1); + + i = i / dvc (j); + } + + Ac = A; + Ac = Ac.single_subsref ("(", idx); + return true; + } + else + { + bool is_changed = false; + octave_idx_type k = i; + octave_idx_type k1 = i - 1; + for (octave_idx_type j = 1; j < nd; j++) + { + if (dva(j) != 1 && k % dvc (j) != k1 % dvc (j)) + { + idx (j) = octave_value ((k % dvc(j)) + 1); + is_changed = true; + } + + k = k / dvc (j); + k1 = k1 / dvc (j); + } + + if (is_changed) + { + Ac = A; + Ac = Ac.single_subsref ("(", idx); + return true; + } + else + return false; + } +} + +#if 0 +// FIXME -- this function is not used; is it OK to delete it? +static void +update_index (octave_value_list& idx, const dim_vector& dv, octave_idx_type i) +{ + octave_idx_type nd = dv.length (); + + if (i == 0) + { + for (octave_idx_type j = nd - 1; j > 0; j--) + idx(j) = octave_value (static_cast<double>(1)); + idx(0) = octave_value (':'); + } + else + { + for (octave_idx_type j = 1; j < nd; j++) + { + idx (j) = octave_value (i % dv (j) + 1); + i = i / dv (j); + } + } +} +#endif + +static void +update_index (Array<int>& idx, const dim_vector& dv, octave_idx_type i) +{ + octave_idx_type nd = dv.length (); + + idx(0) = 0; + for (octave_idx_type j = 1; j < nd; j++) + { + idx (j) = i % dv (j); + i = i / dv (j); + } +} + +DEFUN (bsxfun, args, , + "-*- texinfo -*-\n\ +@deftypefn {Built-in Function} {} bsxfun (@var{f}, @var{A}, @var{B})\n\ +The binary singleton expansion function applier performs broadcasting,\n\ +that is, applies a binary function @var{f} element-by-element to two\n\ +array arguments @var{A} and @var{B}, and expands as necessary\n\ +singleton dimensions in either input argument. @var{f} is a function\n\ +handle, inline function, or string containing the name of the function\n\ +to evaluate. The function @var{f} must be capable of accepting two\n\ +column-vector arguments of equal length, or one column vector argument\n\ +and a scalar.\n\ +\n\ +The dimensions of @var{A} and @var{B} must be equal or singleton. The\n\ +singleton dimensions of the arrays will be expanded to the same\n\ +dimensionality as the other array.\n\ +@seealso{arrayfun, cellfun}\n\ +@end deftypefn") +{ + int nargin = args.length (); + octave_value_list retval; + + if (nargin != 3) + print_usage (); + else + { + octave_value func = args(0); + + if (func.is_string ()) + { + std::string name = func.string_value (); + func = symbol_table::find_function (name); + if (func.is_undefined ()) + error ("bsxfun: invalid function name: %s", name.c_str ()); + } + else if (! (args(0).is_function_handle () || args(0).is_inline_function ())) + error ("bsxfun: F must be a string or function handle"); + + const octave_value A = args (1); + const octave_value B = args (2); + + if (func.is_builtin_function () + || (func.is_function_handle () && ! A.is_object () && ! B.is_object ())) + { + // This may break if the default behavior is overriden. But if you override + // arithmetic operators for builtin classes, you should expect mayhem + // anyway (constant folding etc). Querying is_overloaded may not be + // exactly what we need here. + octave_function *fcn_val = func.function_value (); + if (fcn_val) + { + octave_value tmp = maybe_optimized_builtin (fcn_val->name (), A, B); + if (tmp.is_defined ()) + retval(0) = tmp; + } + } + + if (! error_state && retval.empty ()) + { + dim_vector dva = A.dims (); + octave_idx_type nda = dva.length (); + dim_vector dvb = B.dims (); + octave_idx_type ndb = dvb.length (); + octave_idx_type nd = nda; + + if (nda > ndb) + dvb.resize (nda, 1); + else if (nda < ndb) + { + dva.resize (ndb, 1); + nd = ndb; + } + + for (octave_idx_type i = 0; i < nd; i++) + if (dva (i) != dvb (i) && dva (i) != 1 && dvb (i) != 1) + { + error ("bsxfun: dimensions of A and B must match"); + break; + } + + if (!error_state) + { + // Find the size of the output + dim_vector dvc; + dvc.resize (nd); + + for (octave_idx_type i = 0; i < nd; i++) + dvc (i) = (dva (i) < 1 ? dva (i) : (dvb (i) < 1 ? dvb (i) : + (dva (i) > dvb (i) ? dva (i) : dvb (i)))); + + if (dva == dvb || dva.numel () == 1 || dvb.numel () == 1) + { + octave_value_list inputs; + inputs (0) = A; + inputs (1) = B; + retval = func.do_multi_index_op (1, inputs); + } + else if (dvc.numel () < 1) + { + octave_value_list inputs; + inputs (0) = A.resize (dvc); + inputs (1) = B.resize (dvc); + retval = func.do_multi_index_op (1, inputs); + } + else + { + octave_idx_type ncount = 1; + for (octave_idx_type i = 1; i < nd; i++) + ncount *= dvc (i); + +#define BSXDEF(T) \ + T result_ ## T; \ + bool have_ ## T = false; + + BSXDEF(NDArray); + BSXDEF(ComplexNDArray); + BSXDEF(FloatNDArray); + BSXDEF(FloatComplexNDArray); + BSXDEF(boolNDArray); + BSXDEF(int8NDArray); + BSXDEF(int16NDArray); + BSXDEF(int32NDArray); + BSXDEF(int64NDArray); + BSXDEF(uint8NDArray); + BSXDEF(uint16NDArray); + BSXDEF(uint32NDArray); + BSXDEF(uint64NDArray); + + octave_value Ac ; + octave_value_list idxA; + octave_value Bc; + octave_value_list idxB; + octave_value C; + octave_value_list inputs; + Array<int> ra_idx (dim_vector (dvc.length (), 1), 0); + + + for (octave_idx_type i = 0; i < ncount; i++) + { + if (maybe_update_column (Ac, A, dva, dvc, i, idxA)) + inputs (0) = Ac; + + if (maybe_update_column (Bc, B, dvb, dvc, i, idxB)) + inputs (1) = Bc; + + octave_value_list tmp = func.do_multi_index_op (1, inputs); + + if (error_state) + break; + +#define BSXINIT(T, CLS, EXTRACTOR) \ + (result_type == CLS) \ + { \ + have_ ## T = true; \ + result_ ## T = \ + tmp (0). EXTRACTOR ## _array_value (); \ + result_ ## T .resize (dvc); \ + } + + if (i == 0) + { + if (! tmp(0).is_sparse_type ()) + { + std::string result_type = tmp(0).class_name (); + if (result_type == "double") + { + if (tmp(0).is_real_type ()) + { + have_NDArray = true; + result_NDArray = tmp(0).array_value (); + result_NDArray.resize (dvc); + } + else + { + have_ComplexNDArray = true; + result_ComplexNDArray = + tmp(0).complex_array_value (); + result_ComplexNDArray.resize (dvc); + } + } + else if (result_type == "single") + { + if (tmp(0).is_real_type ()) + { + have_FloatNDArray = true; + result_FloatNDArray = tmp(0).float_array_value (); + result_FloatNDArray.resize (dvc); + } + else + { + have_ComplexNDArray = true; + result_ComplexNDArray = + tmp(0).complex_array_value (); + result_ComplexNDArray.resize (dvc); + } + } + else if BSXINIT(boolNDArray, "logical", bool) + else if BSXINIT(int8NDArray, "int8", int8) + else if BSXINIT(int16NDArray, "int16", int16) + else if BSXINIT(int32NDArray, "int32", int32) + else if BSXINIT(int64NDArray, "int64", int64) + else if BSXINIT(uint8NDArray, "uint8", uint8) + else if BSXINIT(uint16NDArray, "uint16", uint16) + else if BSXINIT(uint32NDArray, "uint32", uint32) + else if BSXINIT(uint64NDArray, "uint64", uint64) + else + { + C = tmp (0); + C = C.resize (dvc); + } + } + } + else + { + update_index (ra_idx, dvc, i); + + if (have_FloatNDArray || + have_FloatComplexNDArray) + { + if (! tmp(0).is_float_type ()) + { + if (have_FloatNDArray) + { + have_FloatNDArray = false; + C = result_FloatNDArray; + } + else + { + have_FloatComplexNDArray = false; + C = result_FloatComplexNDArray; + } + C = do_cat_op (C, tmp(0), ra_idx); + } + else if (tmp(0).is_double_type ()) + { + if (tmp(0).is_complex_type () && + have_FloatNDArray) + { + result_ComplexNDArray = + ComplexNDArray (result_FloatNDArray); + result_ComplexNDArray.insert + (tmp(0).complex_array_value (), ra_idx); + have_FloatComplexNDArray = false; + have_ComplexNDArray = true; + } + else + { + result_NDArray = + NDArray (result_FloatNDArray); + result_NDArray.insert + (tmp(0).array_value (), ra_idx); + have_FloatNDArray = false; + have_NDArray = true; + } + } + else if (tmp(0).is_real_type ()) + result_FloatNDArray.insert + (tmp(0).float_array_value (), ra_idx); + else + { + result_FloatComplexNDArray = + FloatComplexNDArray (result_FloatNDArray); + result_FloatComplexNDArray.insert + (tmp(0).float_complex_array_value (), ra_idx); + have_FloatNDArray = false; + have_FloatComplexNDArray = true; + } + } + else if (have_NDArray) + { + if (! tmp(0).is_float_type ()) + { + have_NDArray = false; + C = result_NDArray; + C = do_cat_op (C, tmp(0), ra_idx); + } + else if (tmp(0).is_real_type ()) + result_NDArray.insert (tmp(0).array_value (), + ra_idx); + else + { + result_ComplexNDArray = + ComplexNDArray (result_NDArray); + result_ComplexNDArray.insert + (tmp(0).complex_array_value (), ra_idx); + have_NDArray = false; + have_ComplexNDArray = true; + } + } + +#define BSXLOOP(T, CLS, EXTRACTOR) \ + (have_ ## T) \ + { \ + if (tmp (0).class_name () != CLS) \ + { \ + have_ ## T = false; \ + C = result_ ## T; \ + C = do_cat_op (C, tmp (0), ra_idx); \ + } \ + else \ + result_ ## T .insert \ + (tmp(0). EXTRACTOR ## _array_value (), \ + ra_idx); \ + } + + else if BSXLOOP(ComplexNDArray, "double", complex) + else if BSXLOOP(boolNDArray, "logical", bool) + else if BSXLOOP(int8NDArray, "int8", int8) + else if BSXLOOP(int16NDArray, "int16", int16) + else if BSXLOOP(int32NDArray, "int32", int32) + else if BSXLOOP(int64NDArray, "int64", int64) + else if BSXLOOP(uint8NDArray, "uint8", uint8) + else if BSXLOOP(uint16NDArray, "uint16", uint16) + else if BSXLOOP(uint32NDArray, "uint32", uint32) + else if BSXLOOP(uint64NDArray, "uint64", uint64) + else + C = do_cat_op (C, tmp(0), ra_idx); + } + } + +#define BSXEND(T) \ + (have_ ## T) \ + retval(0) = result_ ## T; + + if BSXEND(NDArray) + else if BSXEND(ComplexNDArray) + else if BSXEND(FloatNDArray) + else if BSXEND(FloatComplexNDArray) + else if BSXEND(boolNDArray) + else if BSXEND(int8NDArray) + else if BSXEND(int16NDArray) + else if BSXEND(int32NDArray) + else if BSXEND(int64NDArray) + else if BSXEND(uint8NDArray) + else if BSXEND(uint16NDArray) + else if BSXEND(uint32NDArray) + else if BSXEND(uint64NDArray) + else + retval(0) = C; + } + } + } + } + + return retval; +} + +/* + +%!shared a, b, c, f +%! a = randn (4, 4); +%! b = mean (a, 1); +%! c = mean (a, 2); +%! f = @minus; +%!error (bsxfun (f)) +%!error (bsxfun (f, a)) +%!error (bsxfun (a, b)) +%!error (bsxfun (a, b, c)) +%!error (bsxfun (f, a, b, c)) +%!error (bsxfun (f, ones (4, 0), ones (4, 4))) +%!assert (bsxfun (f, ones (4, 0), ones (4, 1)), zeros (4, 0)) +%!assert (bsxfun (f, ones (1, 4), ones (4, 1)), zeros (4, 4)) +%!assert (bsxfun (f, a, b), a - repmat (b, 4, 1)) +%!assert (bsxfun (f, a, c), a - repmat (c, 1, 4)) +%!assert (bsxfun ("minus", ones (1, 4), ones (4, 1)), zeros (4, 4)) + +%!shared a, b, c, f +%! a = randn (4, 4); +%! a(1) *= 1i; +%! b = mean (a, 1); +%! c = mean (a, 2); +%! f = @minus; +%!error (bsxfun (f)) +%!error (bsxfun (f, a)) +%!error (bsxfun (a, b)) +%!error (bsxfun (a, b, c)) +%!error (bsxfun (f, a, b, c)) +%!error (bsxfun (f, ones (4, 0), ones (4, 4))) +%!assert (bsxfun (f, ones (4, 0), ones (4, 1)), zeros (4, 0)) +%!assert (bsxfun (f, ones (1, 4), ones (4, 1)), zeros (4, 4)) +%!assert (bsxfun (f, a, b), a - repmat (b, 4, 1)) +%!assert (bsxfun (f, a, c), a - repmat (c, 1, 4)) +%!assert (bsxfun ("minus", ones (1, 4), ones (4, 1)), zeros (4, 4)) + +%!shared a, b, c, f +%! a = randn (4, 4); +%! a(end) *= 1i; +%! b = mean (a, 1); +%! c = mean (a, 2); +%! f = @minus; +%!error (bsxfun (f)) +%!error (bsxfun (f, a)) +%!error (bsxfun (a, b)) +%!error (bsxfun (a, b, c)) +%!error (bsxfun (f, a, b, c)) +%!error (bsxfun (f, ones (4, 0), ones (4, 4))) +%!assert (bsxfun (f, ones (4, 0), ones (4, 1)), zeros (4, 0)) +%!assert (bsxfun (f, ones (1, 4), ones (4, 1)), zeros (4, 4)) +%!assert (bsxfun (f, a, b), a - repmat (b, 4, 1)) +%!assert (bsxfun (f, a, c), a - repmat (c, 1, 4)) +%!assert (bsxfun ("minus", ones (1, 4), ones (4, 1)), zeros (4, 4)) + +%!shared a, b, c, f +%! a = randn (4, 4); +%! b = a (1, :); +%! c = a (:, 1); +%! f = @(x, y) x == y; +%!error (bsxfun (f)) +%!error (bsxfun (f, a)) +%!error (bsxfun (a, b)) +%!error (bsxfun (a, b, c)) +%!error (bsxfun (f, a, b, c)) +%!error (bsxfun (f, ones (4, 0), ones (4, 4))) +%!assert (bsxfun (f, ones (4, 0), ones (4, 1)), zeros (4, 0, "logical")) +%!assert (bsxfun (f, ones (1, 4), ones (4, 1)), ones (4, 4, "logical")) +%!assert (bsxfun (f, a, b), a == repmat (b, 4, 1)) +%!assert (bsxfun (f, a, c), a == repmat (c, 1, 4)) + +%!shared a, b, c, d, f +%! a = randn (4, 4, 4); +%! b = mean (a, 1); +%! c = mean (a, 2); +%! d = mean (a, 3); +%! f = @minus; +%!error (bsxfun (f, ones ([4, 0, 4]), ones ([4, 4, 4]))) +%!assert (bsxfun (f, ones ([4, 0, 4]), ones ([4, 1, 4])), zeros ([4, 0, 4])) +%!assert (bsxfun (f, ones ([4, 4, 0]), ones ([4, 1, 1])), zeros ([4, 4, 0])) +%!assert (bsxfun (f, ones ([1, 4, 4]), ones ([4, 1, 4])), zeros ([4, 4, 4])) +%!assert (bsxfun (f, ones ([4, 4, 1]), ones ([4, 1, 4])), zeros ([4, 4, 4])) +%!assert (bsxfun (f, ones ([4, 1, 4]), ones ([1, 4, 4])), zeros ([4, 4, 4])) +%!assert (bsxfun (f, ones ([4, 1, 4]), ones ([1, 4, 1])), zeros ([4, 4, 4])) +%!assert (bsxfun (f, a, b), a - repmat (b, [4, 1, 1])) +%!assert (bsxfun (f, a, c), a - repmat (c, [1, 4, 1])) +%!assert (bsxfun (f, a, d), a - repmat (d, [1, 1, 4])) +%!assert (bsxfun ("minus", ones ([4, 0, 4]), ones ([4, 1, 4])), zeros ([4, 0, 4])) + +%% The test below is a very hard case to treat +%!assert (bsxfun (f, ones ([4, 1, 4, 1]), ones ([1, 4, 1, 4])), zeros ([4, 4, 4, 4])); + +%!shared a, b, aa, bb +%! a = randn (3, 1, 3); +%! aa = a(:, ones (1, 3), :, ones (1, 3)); +%! b = randn (1, 3, 3, 3); +%! bb = b(ones (1, 3), :, :, :); +%!assert (bsxfun (@plus, a, b), aa + bb) +%!assert (bsxfun (@minus, a, b), aa - bb) +%!assert (bsxfun (@times, a, b), aa .* bb) +%!assert (bsxfun (@rdivide, a, b), aa ./ bb) +%!assert (bsxfun (@ldivide, a, b), aa .\ bb) +%!assert (bsxfun (@power, a, b), aa .^ bb) +%!assert (bsxfun (@power, abs (a), b), abs (aa) .^ bb) +%!assert (bsxfun (@eq, round (a), round (b)), round (aa) == round (bb)) +%!assert (bsxfun (@ne, round (a), round (b)), round (aa) != round (bb)) +%!assert (bsxfun (@lt, a, b), aa < bb) +%!assert (bsxfun (@le, a, b), aa <= bb) +%!assert (bsxfun (@gt, a, b), aa > bb) +%!assert (bsxfun (@ge, a, b), aa >= bb) +%!assert (bsxfun (@min, a, b), min (aa, bb)) +%!assert (bsxfun (@max, a, b), max (aa, bb)) +%!assert (bsxfun (@and, a > 0, b > 0), (aa > 0) & (bb > 0)) +%!assert (bsxfun (@or, a > 0, b > 0), (aa > 0) | (bb > 0)) + +%% Test automatic bsxfun +% +%!test +%! funs = {@plus, @minus, @times, @rdivide, @ldivide, @power, @max, @min, \ +%! @rem, @mod, @atan2, @hypot, @eq, @ne, @lt, @le, @gt, @ge, \ +%! @and, @or, @xor }; +%! +%! float_types = {@single, @double}; +%! int_types = {@int8, @int16, @int32, @int64, \ +%! @uint8, @uint16, @uint32, @uint64}; +%! +%! x = rand (3) * 10-5; +%! y = rand (3,1) * 10-5; +%! +%! for i=1:length (funs) +%! for j = 1:length (float_types) +%! for k = 1:length (int_types) +%! +%! fun = funs{i}; +%! f_type = float_types{j}; +%! i_type = int_types{k}; +%! +%! assert (bsxfun (fun, f_type (x), i_type (y)), \ +%! fun (f_type(x), i_type (y))); +%! assert (bsxfun (fun, f_type (y), i_type (x)), \ +%! fun (f_type(y), i_type (x))); +%! +%! assert (bsxfun (fun, i_type (x), i_type (y)), \ +%! fun (i_type (x), i_type (y))); +%! assert (bsxfun (fun, i_type (y), i_type (x)), \ +%! fun (i_type (y), i_type (x))); +%! +%! assert (bsxfun (fun, f_type (x), f_type (y)), \ +%! fun (f_type (x), f_type (y))); +%! assert (bsxfun (fun, f_type(y), f_type(x)), \ +%! fun (f_type (y), f_type (x))); +%! endfor +%! endfor +%! endfor +%! +*/