# HG changeset patch # User John W. Eaton # Date 1531176785 25200 # Node ID 8c0a350409742fdbbdef7b0e6a63f371f6a6d377 # Parent 65a28dec405b77afd828fee864ffa7a5c0788a11 Fix blkmm to work with empty matrices (bug #54261). * dot.cc (blkmm_internal): New templated function with overrides to call the correct Fortran routine based on template Array type. * dot.cc (get_blkmm_dims): New function to verify dimensions of arguments to Fblkmm. * dot.cc (do_blkmm): Templated function which uses octave_value_extract to get the correct data type before calling blkmm_internal. * dot.cc (Fblkmm): Simplify function by extracting most of the actions into get_blkmm_dims and blkmm_internal. New function determines input data type and calls templated do_blkmm appropriately. Add BIST test for bug #54261. diff -r 65a28dec405b -r 8c0a35040974 libinterp/corefcn/dot.cc --- a/libinterp/corefcn/dot.cc Thu Jul 05 14:17:13 2018 -0400 +++ b/libinterp/corefcn/dot.cc Mon Jul 09 15:53:05 2018 -0700 @@ -227,6 +227,107 @@ %!error dot ([1 2], [1 2], 0) */ +template +void +blkmm_internal (const T& x, const T& y, T& z, + F77_INT m, F77_INT n, F77_INT k, F77_INT np); + +template <> +void +blkmm_internal (const FloatComplexNDArray& x, const FloatComplexNDArray& y, + FloatComplexNDArray& z, + F77_INT m, F77_INT n, F77_INT k, F77_INT np) +{ + F77_XFCN (cmatm3, CMATM3, (m, n, k, np, + F77_CONST_CMPLX_ARG (x.data ()), + F77_CONST_CMPLX_ARG (y.data ()), + F77_CMPLX_ARG (z.fortran_vec ()))); +} + +template <> +void +blkmm_internal (const ComplexNDArray& x, const ComplexNDArray& y, + ComplexNDArray& z, + F77_INT m, F77_INT n, F77_INT k, F77_INT np) +{ + F77_XFCN (zmatm3, ZMATM3, (m, n, k, np, + F77_CONST_DBLE_CMPLX_ARG (x.data ()), + F77_CONST_DBLE_CMPLX_ARG (y.data ()), + F77_DBLE_CMPLX_ARG (z.fortran_vec ()))); +} + +template <> +void +blkmm_internal (const FloatNDArray& x, const FloatNDArray& y, FloatNDArray& z, + F77_INT m, F77_INT n, F77_INT k, F77_INT np) +{ + F77_XFCN (smatm3, SMATM3, (m, n, k, np, + x.data (), y.data (), + z.fortran_vec ())); +} + +template <> +void +blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z, + F77_INT m, F77_INT n, F77_INT k, F77_INT np) +{ + F77_XFCN (dmatm3, DMATM3, (m, n, k, np, + x.data (), y.data (), + z.fortran_vec ())); +} + +static void +get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy, + F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np, + dim_vector& dimz) +{ + int nd = dimx.ndims (); + + m = octave::to_f77_int (dimx(0)); + k = octave::to_f77_int (dimx(1)); + n = octave::to_f77_int (dimy(1)); + + octave_idx_type tmp_np = 1; + + bool match = dimy(0) == k && nd == dimy.ndims (); + + dimz = dim_vector::alloc (nd); + + dimz(0) = m; + dimz(1) = n; + for (int i = 2; match && i < nd; i++) + { + match = match && dimx(i) == dimy(i); + dimz(i) = dimx(i); + tmp_np *= dimz(i); + } + + np = octave::to_f77_int (tmp_np); + + if (! match) + error ("blkmm: A and B dimensions don't match: (%s) and (%s)", + dimx.str ().c_str (), dimy.str ().c_str ()); +} + +template +T +do_blkmm (const octave_value& xov, const octave_value& yov) +{ + const T x = octave_value_extract (xov); + const T y = octave_value_extract (yov); + F77_INT m, n, k, np; + dim_vector dimz; + + get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz); + + T z (dimz); + + if (n != 0 && m != 0) + blkmm_internal (x, y, z, m, n, k, np); + + return z; +} + DEFUN (blkmm, args, , doc: /* -*- texinfo -*- @deftypefn {} {} blkmm (@var{A}, @var{B}) @@ -257,78 +358,19 @@ if (! argx.isnumeric () || ! argy.isnumeric ()) error ("blkmm: A and B must be numeric"); - const dim_vector dimx = argx.dims (); - const dim_vector dimy = argy.dims (); - int nd = dimx.ndims (); - F77_INT m = octave::to_f77_int (dimx(0)); - F77_INT k = octave::to_f77_int (dimx(1)); - F77_INT n = octave::to_f77_int (dimy(1)); - octave_idx_type tmp_np = 1; - bool match = dimy(0) == k && nd == dimy.ndims (); - dim_vector dimz = dim_vector::alloc (nd); - dimz(0) = m; - dimz(1) = n; - for (int i = 2; match && i < nd; i++) - { - match = match && dimx(i) == dimy(i); - dimz(i) = dimx(i); - tmp_np *= dimz(i); - } - F77_INT np = octave::to_f77_int (tmp_np); - - if (! match) - error ("blkmm: A and B dimensions don't match: (%s) and (%s)", - dimx.str ().c_str (), dimy.str ().c_str ()); - if (argx.iscomplex () || argy.iscomplex ()) { if (argx.is_single_type () || argy.is_single_type ()) - { - FloatComplexNDArray x = argx.float_complex_array_value (); - FloatComplexNDArray y = argy.float_complex_array_value (); - FloatComplexNDArray z (dimz); - - F77_XFCN (cmatm3, CMATM3, (m, n, k, np, - F77_CONST_CMPLX_ARG (x.data ()), F77_CONST_CMPLX_ARG (y.data ()), - F77_CMPLX_ARG (z.fortran_vec ()))); - retval = z; - } + retval = do_blkmm (argx, argy); else - { - ComplexNDArray x = argx.complex_array_value (); - ComplexNDArray y = argy.complex_array_value (); - ComplexNDArray z (dimz); - - F77_XFCN (zmatm3, ZMATM3, (m, n, k, np, - F77_CONST_DBLE_CMPLX_ARG (x.data ()), F77_CONST_DBLE_CMPLX_ARG (y.data ()), - F77_DBLE_CMPLX_ARG (z.fortran_vec ()))); - retval = z; - } + retval = do_blkmm (argx, argy); } else { if (argx.is_single_type () || argy.is_single_type ()) - { - FloatNDArray x = argx.float_array_value (); - FloatNDArray y = argy.float_array_value (); - FloatNDArray z (dimz); - - F77_XFCN (smatm3, SMATM3, (m, n, k, np, - x.data (), y.data (), - z.fortran_vec ())); - retval = z; - } + retval = do_blkmm (argx, argy); else - { - NDArray x = argx.array_value (); - NDArray y = argy.array_value (); - NDArray z (dimz); - - F77_XFCN (dmatm3, DMATM3, (m, n, k, np, - x.data (), y.data (), - z.fortran_vec ())); - retval = z; - } + retval = do_blkmm (argx, argy); } return retval; @@ -353,11 +395,21 @@ %! assert (blkmm (single (x), single (x)), single (z)); %! assert (blkmm (x, single (x)), single (z)); +%!test <*54261> +%! x = ones (0, 3, 3); +%! y = ones (3, 5, 3); +%! z = blkmm (x,y); +%! assert (size (z), [0, 5, 3]); +%! x = ones (1, 3, 3); +%! y = ones (3, 0, 3); +%! z = blkmm (x,y); +%! assert (size (z), [1, 0, 3]); + ## Test input validation %!error blkmm () %!error blkmm (1) %!error blkmm (1,2,3) -%!error blkmm (ones (2,2), ones (3,3)) %!error blkmm ({1,2}, [3,4]) %!error blkmm ([3,4], {1,2}) +%!error blkmm (ones (2,2), ones (3,3)) */