changeset 25567:8c0a35040974 stable

Fix blkmm to work with empty matrices (bug #54261). * dot.cc (blkmm_internal): New templated function with overrides to call the correct Fortran routine based on template Array type. * dot.cc (get_blkmm_dims): New function to verify dimensions of arguments to Fblkmm. * dot.cc (do_blkmm): Templated function which uses octave_value_extract to get the correct data type before calling blkmm_internal. * dot.cc (Fblkmm): Simplify function by extracting most of the actions into get_blkmm_dims and blkmm_internal. New function determines input data type and calls templated do_blkmm appropriately. Add BIST test for bug #54261.
author John W. Eaton <jwe@octave.org>
date Mon, 09 Jul 2018 15:53:05 -0700
parents 65a28dec405b
children e6b037514f56 e8961d677661
files libinterp/corefcn/dot.cc
diffstat 1 files changed, 116 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/dot.cc	Thu Jul 05 14:17:13 2018 -0400
+++ b/libinterp/corefcn/dot.cc	Mon Jul 09 15:53:05 2018 -0700
@@ -227,6 +227,107 @@
 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
 */
 
+template <typename T>
+void
+blkmm_internal (const T& x, const T& y, T& z,
+                F77_INT m, F77_INT n, F77_INT k, F77_INT np);
+
+template <>
+void
+blkmm_internal (const FloatComplexNDArray& x, const FloatComplexNDArray& y,
+                FloatComplexNDArray& z,
+                F77_INT m, F77_INT n, F77_INT k, F77_INT np)
+{
+  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
+                             F77_CONST_CMPLX_ARG (x.data ()),
+                             F77_CONST_CMPLX_ARG (y.data ()),
+                             F77_CMPLX_ARG (z.fortran_vec ())));
+}
+
+template <>
+void
+blkmm_internal (const ComplexNDArray& x, const ComplexNDArray& y,
+                ComplexNDArray& z,
+                F77_INT m, F77_INT n, F77_INT k, F77_INT np)
+{
+  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
+                             F77_CONST_DBLE_CMPLX_ARG (x.data ()),
+                             F77_CONST_DBLE_CMPLX_ARG (y.data ()),
+                             F77_DBLE_CMPLX_ARG (z.fortran_vec ())));
+}
+
+template <>
+void
+blkmm_internal (const FloatNDArray& x, const FloatNDArray& y, FloatNDArray& z,
+                F77_INT m, F77_INT n, F77_INT k, F77_INT np)
+{
+  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
+                             x.data (), y.data (),
+                             z.fortran_vec ()));
+}
+
+template <>
+void
+blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z,
+                F77_INT m, F77_INT n, F77_INT k, F77_INT np)
+{
+  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
+                             x.data (), y.data (),
+                             z.fortran_vec ()));
+}
+
+static void
+get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy,
+                F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np,
+                dim_vector& dimz)
+{
+  int nd = dimx.ndims ();
+
+  m = octave::to_f77_int (dimx(0));
+  k = octave::to_f77_int (dimx(1));
+  n = octave::to_f77_int (dimy(1));
+
+  octave_idx_type tmp_np = 1;
+
+  bool match = dimy(0) == k && nd == dimy.ndims ();
+
+  dimz = dim_vector::alloc (nd);
+
+  dimz(0) = m;
+  dimz(1) = n;
+  for (int i = 2; match && i < nd; i++)
+    {
+      match = match && dimx(i) == dimy(i);
+      dimz(i) = dimx(i);
+      tmp_np *= dimz(i);
+    }
+
+  np = octave::to_f77_int (tmp_np);
+
+  if (! match)
+    error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
+           dimx.str ().c_str (), dimy.str ().c_str ());
+}
+
+template <typename T>
+T
+do_blkmm (const octave_value& xov, const octave_value& yov)
+{
+  const T x = octave_value_extract<T> (xov); 
+  const T y = octave_value_extract<T> (yov); 
+  F77_INT m, n, k, np;
+  dim_vector dimz;
+
+  get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz);
+
+  T z (dimz);
+
+  if (n != 0 && m != 0)
+    blkmm_internal<T> (x, y, z, m, n, k, np);
+
+  return z;
+}
+
 DEFUN (blkmm, args, ,
        doc: /* -*- texinfo -*-
 @deftypefn {} {} blkmm (@var{A}, @var{B})
@@ -257,78 +358,19 @@
   if (! argx.isnumeric () || ! argy.isnumeric ())
     error ("blkmm: A and B must be numeric");
 
-  const dim_vector dimx = argx.dims ();
-  const dim_vector dimy = argy.dims ();
-  int nd = dimx.ndims ();
-  F77_INT m = octave::to_f77_int (dimx(0));
-  F77_INT k = octave::to_f77_int (dimx(1));
-  F77_INT n = octave::to_f77_int (dimy(1));
-  octave_idx_type tmp_np = 1;
-  bool match = dimy(0) == k && nd == dimy.ndims ();
-  dim_vector dimz = dim_vector::alloc (nd);
-  dimz(0) = m;
-  dimz(1) = n;
-  for (int i = 2; match && i < nd; i++)
-    {
-      match = match && dimx(i) == dimy(i);
-      dimz(i) = dimx(i);
-      tmp_np *= dimz(i);
-    }
-  F77_INT np = octave::to_f77_int (tmp_np);
-
-  if (! match)
-    error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
-           dimx.str ().c_str (), dimy.str ().c_str ());
-
   if (argx.iscomplex () || argy.iscomplex ())
     {
       if (argx.is_single_type () || argy.is_single_type ())
-        {
-          FloatComplexNDArray x = argx.float_complex_array_value ();
-          FloatComplexNDArray y = argy.float_complex_array_value ();
-          FloatComplexNDArray z (dimz);
-
-          F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
-                                     F77_CONST_CMPLX_ARG (x.data ()), F77_CONST_CMPLX_ARG (y.data ()),
-                                     F77_CMPLX_ARG (z.fortran_vec ())));
-          retval = z;
-        }
+        retval = do_blkmm<FloatComplexNDArray> (argx, argy);
       else
-        {
-          ComplexNDArray x = argx.complex_array_value ();
-          ComplexNDArray y = argy.complex_array_value ();
-          ComplexNDArray z (dimz);
-
-          F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
-                                     F77_CONST_DBLE_CMPLX_ARG (x.data ()), F77_CONST_DBLE_CMPLX_ARG (y.data ()),
-                                     F77_DBLE_CMPLX_ARG (z.fortran_vec ())));
-          retval = z;
-        }
+        retval = do_blkmm<ComplexNDArray> (argx, argy);
     }
   else
     {
       if (argx.is_single_type () || argy.is_single_type ())
-        {
-          FloatNDArray x = argx.float_array_value ();
-          FloatNDArray y = argy.float_array_value ();
-          FloatNDArray z (dimz);
-
-          F77_XFCN (smatm3, SMATM3, (m, n, k, np,
-                                     x.data (), y.data (),
-                                     z.fortran_vec ()));
-          retval = z;
-        }
+        retval = do_blkmm<FloatNDArray> (argx, argy);
       else
-        {
-          NDArray x = argx.array_value ();
-          NDArray y = argy.array_value ();
-          NDArray z (dimz);
-
-          F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
-                                     x.data (), y.data (),
-                                     z.fortran_vec ()));
-          retval = z;
-        }
+        retval = do_blkmm<NDArray> (argx, argy);
     }
 
   return retval;
@@ -353,11 +395,21 @@
 %! assert (blkmm (single (x), single (x)), single (z));
 %! assert (blkmm (x, single (x)), single (z));
 
+%!test <*54261>
+%! x = ones (0, 3, 3);
+%! y = ones (3, 5, 3);
+%! z = blkmm (x,y);
+%! assert (size (z), [0, 5, 3]);
+%! x = ones (1, 3, 3);
+%! y = ones (3, 0, 3);
+%! z = blkmm (x,y);
+%! assert (size (z), [1, 0, 3]);
+
 ## Test input validation
 %!error blkmm ()
 %!error blkmm (1)
 %!error blkmm (1,2,3)
-%!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
+%!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
 */