changeset 25668:e84d53ffcae5 jwe

tmp template commit
author John W. Eaton <jwe@octave.org>
date Tue, 08 Nov 2016 15:06:28 -0500
parents 8b8832ce72b2
children
files liboctave/array/MArray.h liboctave/operators/Sparse-op-defs.h
diffstat 2 files changed, 216 insertions(+), 75 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/array/MArray.h	Mon Nov 07 11:11:26 2016 -0500
+++ b/liboctave/array/MArray.h	Tue Nov 08 15:06:28 2016 -0500
@@ -54,6 +54,64 @@
 template <typename T> MArray<T> quotient (const MArray<T>&, const MArray<T>&);
 template <typename T> MArray<T> product (const MArray<T>&, const MArray<T>&);
 
+template <typename T1, typename T2, typename OP>
+auto
+mm_bin_op (const MArray<T1>& x, const MArray<T2>& y, OP op) -> MArray<decltype (op (T1 (), T2 ()))>
+{
+  typedef decltype (op (T1 (), T2 ())) RT;
+
+  dim_vector dx = x.dims ();
+  dim_vector dy = y.dims ();
+
+  if (dx == dy)
+    {
+      Array<R> r (dx);
+      op (r.numel (), r.fortran_vec (), x.data (), y.data ());
+      return r;
+    }
+  else if (is_valid_bsxfun (opname, dx, dy))
+    {
+      return do_bsxfun_op (x, y, op, op1, op2);
+    }
+  else
+    octave::err_nonconformant (opname, dx, dy);
+
+  MArray<RT> r (x.dims ());
+
+  for (octave_idx_type i = 0; i < n; i++)
+    r[i] = op (x[i], y[i]);
+
+  return r;
+}
+
+template <typename MT, typename ST, typename OP>
+auto
+ms_bin_op (const MArray<MT>& x, const ST& y, OP op) -> MArray<decltype (op (MT (), ST ()))>
+{
+  typedef decltype (op (MT (), ST ())) RT;
+
+  MArray<RT> r (x.dims ());
+
+  for (octave_idx_type i = 0; i < n; i++)
+    r[i] = op (x[i], y);
+
+  return r;
+}
+
+template <typename MT, typename ST, typename OP>
+auto
+sm_bin_op (const ST& x, const MArray<MT>& y, OP op) -> MArray<decltype (op (ST (), MT ()))>
+{
+  typedef decltype (op (MT (), ST ())) RT;
+
+  MArray<RT> r (x.dims ());
+
+  for (octave_idx_type i = 0; i < n; i++)
+    r[i] = op (x, y[i]);
+
+  return r;
+}
+
 //! Template for N-dimensional array classes with like-type math operators.
 template <typename T>
 class
--- a/liboctave/operators/Sparse-op-defs.h	Mon Nov 07 11:11:26 2016 -0500
+++ b/liboctave/operators/Sparse-op-defs.h	Tue Nov 08 15:06:28 2016 -0500
@@ -714,83 +714,166 @@
 
 // matrix by sparse matrix operations.
 
-#define SPARSE_MSM_BIN_OP_1(R, F, OP, M1, M2)                           \
-  R                                                                     \
-  F (const M1& m1, const M2& m2)                                        \
-  {                                                                     \
-    R r;                                                                \
-                                                                        \
-    octave_idx_type m1_nr = m1.rows ();                                 \
-    octave_idx_type m1_nc = m1.cols ();                                 \
-                                                                        \
-    octave_idx_type m2_nr = m2.rows ();                                 \
-    octave_idx_type m2_nc = m2.cols ();                                 \
-                                                                        \
-    if (m2_nr == 1 && m2_nc == 1)                                       \
-      r = R (m1 OP m2.elem (0,0));                                      \
-    else if (m1_nr != m2_nr || m1_nc != m2_nc)                          \
-      octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc);       \
-    else                                                                \
-      {                                                                 \
-        r = R (F (m1, m2.matrix_value ()));                             \
-      }                                                                 \
-    return r;                                                           \
+template <typename M1, typename M2>
+auto
+msm_add_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () + M2 ())>
+{
+  typedef decltype (M1 () + M2 ()) RT;
+
+  MArray<RT> r;
+
+  octave_idx_type m1_nr = m1.rows ();
+  octave_idx_type m1_nc = m1.cols ();
+
+  octave_idx_type m2_nr = m2.rows ();
+  octave_idx_type m2_nc = m2.cols ();
+
+  if (m2_nr == 1 && m2_nc == 1)
+    r = MArray<RT> (m1 + m2.elem (0,0));
+  else if (m1_nr != m2_nr || m1_nc != m2_nc)
+    octave::err_nonconformant ("operator +", m1_nr, m1_nc, m2_nr, m2_nc);
+  else
+    r = m1 + MArray<M2> (m2.array_value ());
+
+  return r;
+}
+
+template <typename M1, typename M2>
+auto
+msm_sub_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () - M2 ())>
+{
+  MArray<decltype (M1 () - M2 ())> r;
+
+  octave_idx_type m1_nr = m1.rows ();
+  octave_idx_type m1_nc = m1.cols ();
+
+  octave_idx_type m2_nr = m2.rows ();
+  octave_idx_type m2_nc = m2.cols ();
+
+  if (m2_nr == 1 && m2_nc == 1)
+    r = m1 - m2.elem (0,0);
+  else if (m1_nr != m2_nr || m1_nc != m2_nc)
+    octave::err_nonconformant ("operator -", m1_nr, m1_nc, m2_nr, m2_nc);
+  else
+    r = m1 - MArray<M2> (m2.array_value ());
+
+  return r;
+}
+
+template <typename M1, typename M2>
+auto
+msm_mul_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () * M2 ())>
+{
+  typedef decltype (M1 () * M2 ()) RT;
+
+  MSparse<RT> r;
+
+  octave_idx_type m1_nr = m1.rows ();
+  octave_idx_type m1_nc = m1.cols ();
+
+  octave_idx_type m2_nr = m2.rows ();
+  octave_idx_type m2_nc = m2.cols ();
+
+  if (m2_nr == 1 && m2_nc == 1)
+    r = MSparse<RT> (m1 * m2.elem (0,0));
+  else if (m1_nr != m2_nr || m1_nc != m2_nc)
+    octave::err_nonconformant ("operator *", m1_nr, m1_nc, m2_nr, m2_nc);
+  else
+    {
+      if (do_mx_check (m1, mx_inline_all_finite<M1>))
+        {
+          /* Sparsity pattern is preserved. */
+          octave_idx_type m2_nz = m2.nnz ();
+          r = MSparse<RT> (m2_nr, m2_nc, m2_nz);
+          for (octave_idx_type j = 0, k = 0; j < m2_nc; j++)
+            {
+              octave_quit ();
+              for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++)
+                {
+                  octave_idx_type mri = m2.ridx (i);
+                  RT x = m1(mri, j) * m2.data (i);
+                  if (x != 0.0)
+                    {
+                      r.xdata (k) = x;
+                      r.xridx (k) = m2.ridx (i);
+                      k++;
+                    }
+                }
+              r.xcidx (j+1) = k;
+            }
+          r.maybe_compress (false);
+          return r;
+        }
+      else
+        r = MSparse<RT> (product (m1, MArray<M2> (m2.array_value ())));
+    }
+
+  return r;
+}
+
+template <typename M1, typename M2>
+auto
+msm_div_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () / M2 ())>
+{
+  typedef decltype (M1 () / M2 ()) RT;
+
+  MSparse<RT> r;
+
+  octave_idx_type m1_nr = m1.rows ();
+  octave_idx_type m1_nc = m1.cols ();
+
+  octave_idx_type m2_nr = m2.rows ();
+  octave_idx_type m2_nc = m2.cols ();
+
+  if (m2_nr == 1 && m2_nc == 1)
+    r = MSparse<RT> (m1 / m2.elem (0,0));
+  else if (m1_nr != m2_nr || m1_nc != m2_nc)
+    octave::err_nonconformant ("operator /", m1_nr, m1_nc, m2_nr, m2_nc);
+  else
+    {
+      if (do_mx_check (m1, mx_inline_all_finite<M1>))
+        {
+          /* Sparsity pattern is preserved. */
+          octave_idx_type m2_nz = m2.nnz ();
+          r = MSparse<RT> (m2_nr, m2_nc, m2_nz);
+          for (octave_idx_type j = 0, k = 0; j < m2_nc; j++)
+            {
+              octave_quit ();
+              for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++)
+                {
+                  octave_idx_type mri = m2.ridx (i);
+                  RT x = m1(mri, j) / m2.data (i);
+                  if (x != 0.0)
+                    {
+                      r.xdata (k) = x;
+                      r.xridx (k) = m2.ridx (i);
+                      k++;
+                    }
+                }
+              r.xcidx (j+1) = k;
+            }
+          r.maybe_compress (false);
+          return r;
+        }
+      else
+        r = MSparse<RT> (quotient (m1, MArray<M2> (m2.array_value ())));
+    }
+
+  return r;
+}
+
+#define SPARSE_MSM_BIN_OP(R, F, OP_FN, M1, M2)  \
+  R                                             \
+  F (const M1& m1, const M2& m2)                \
+  {                                             \
+    return OP_FN (m1, m2);                      \
   }
 
-#define SPARSE_MSM_BIN_OP_2(R, F, OP, M1, M2)                           \
-  R                                                                     \
-  F (const M1& m1, const M2& m2)                                        \
-  {                                                                     \
-    R r;                                                                \
-                                                                        \
-    octave_idx_type m1_nr = m1.rows ();                                 \
-    octave_idx_type m1_nc = m1.cols ();                                 \
-                                                                        \
-    octave_idx_type m2_nr = m2.rows ();                                 \
-    octave_idx_type m2_nc = m2.cols ();                                 \
-                                                                        \
-    if (m2_nr == 1 && m2_nc == 1)                                       \
-      r = R (m1 OP m2.elem (0,0));                                      \
-    else if (m1_nr != m2_nr || m1_nc != m2_nc)                          \
-      octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc);               \
-    else                                                                \
-      {                                                                 \
-        if (do_mx_check (m1, mx_inline_all_finite<M1::element_type>))   \
-          {                                                             \
-            /* Sparsity pattern is preserved. */                        \
-            octave_idx_type m2_nz = m2.nnz ();                          \
-            r = R (m2_nr, m2_nc, m2_nz);                                \
-            for (octave_idx_type j = 0, k = 0; j < m2_nc; j++)          \
-              {                                                         \
-                octave_quit ();                                         \
-                for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) \
-                  {                                                     \
-                    octave_idx_type mri = m2.ridx (i);                  \
-                    R::element_type x = m1(mri, j) OP m2.data (i);      \
-                    if (x != 0.0)                                       \
-                      {                                                 \
-                        r.xdata (k) = x;                                \
-                        r.xridx (k) = m2.ridx (i);                      \
-                        k++;                                            \
-                      }                                                 \
-                  }                                                     \
-                r.xcidx (j+1) = k;                                      \
-              }                                                         \
-            r.maybe_compress (false);                                   \
-            return r;                                                   \
-          }                                                             \
-        else                                                            \
-          r = R (F (m1, m2.matrix_value ()));                           \
-      }                                                                 \
-                                                                        \
-    return r;                                                           \
-  }
-
-#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2)              \
-  SPARSE_MSM_BIN_OP_1 (R1, operator +,  +, M1, M2)      \
-  SPARSE_MSM_BIN_OP_1 (R1, operator -,  -, M1, M2)      \
-  SPARSE_MSM_BIN_OP_2 (R2, product,     *, M1, M2)      \
-  SPARSE_MSM_BIN_OP_1 (R2, quotient,    /, M1, M2)
+#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2)                      \
+  SPARSE_MSM_BIN_OP (R1, operator +, msm_add_op, M1, M2)        \
+  SPARSE_MSM_BIN_OP (R1, operator -, msm_sub_op, M1, M2)        \
+  SPARSE_MSM_BIN_OP (R2, product,    msm_mul_op, M1, M2)        \
+  SPARSE_MSM_BIN_OP (R2, quotient,   msm_div_op, M1, M2)
 
 #define SPARSE_MSM_CMP_OP(F, OP, M1, M2)                                \
   SparseBoolMatrix                                                      \