# HG changeset patch # User John W. Eaton # Date 1478635588 18000 # Node ID e84d53ffcae50224d51a471e7038c371c39d1d12 # Parent 8b8832ce72b2ce41adebe0eec8c29ebe45e2f673 tmp template commit diff -r 8b8832ce72b2 -r e84d53ffcae5 liboctave/array/MArray.h --- a/liboctave/array/MArray.h Mon Nov 07 11:11:26 2016 -0500 +++ b/liboctave/array/MArray.h Tue Nov 08 15:06:28 2016 -0500 @@ -54,6 +54,64 @@ template MArray quotient (const MArray&, const MArray&); template MArray product (const MArray&, const MArray&); +template +auto +mm_bin_op (const MArray& x, const MArray& y, OP op) -> MArray +{ + typedef decltype (op (T1 (), T2 ())) RT; + + dim_vector dx = x.dims (); + dim_vector dy = y.dims (); + + if (dx == dy) + { + Array r (dx); + op (r.numel (), r.fortran_vec (), x.data (), y.data ()); + return r; + } + else if (is_valid_bsxfun (opname, dx, dy)) + { + return do_bsxfun_op (x, y, op, op1, op2); + } + else + octave::err_nonconformant (opname, dx, dy); + + MArray r (x.dims ()); + + for (octave_idx_type i = 0; i < n; i++) + r[i] = op (x[i], y[i]); + + return r; +} + +template +auto +ms_bin_op (const MArray& x, const ST& y, OP op) -> MArray +{ + typedef decltype (op (MT (), ST ())) RT; + + MArray r (x.dims ()); + + for (octave_idx_type i = 0; i < n; i++) + r[i] = op (x[i], y); + + return r; +} + +template +auto +sm_bin_op (const ST& x, const MArray& y, OP op) -> MArray +{ + typedef decltype (op (MT (), ST ())) RT; + + MArray r (x.dims ()); + + for (octave_idx_type i = 0; i < n; i++) + r[i] = op (x, y[i]); + + return r; +} + //! Template for N-dimensional array classes with like-type math operators. template class diff -r 8b8832ce72b2 -r e84d53ffcae5 liboctave/operators/Sparse-op-defs.h --- a/liboctave/operators/Sparse-op-defs.h Mon Nov 07 11:11:26 2016 -0500 +++ b/liboctave/operators/Sparse-op-defs.h Tue Nov 08 15:06:28 2016 -0500 @@ -714,83 +714,166 @@ // matrix by sparse matrix operations. -#define SPARSE_MSM_BIN_OP_1(R, F, OP, M1, M2) \ - R \ - F (const M1& m1, const M2& m2) \ - { \ - R r; \ - \ - octave_idx_type m1_nr = m1.rows (); \ - octave_idx_type m1_nc = m1.cols (); \ - \ - octave_idx_type m2_nr = m2.rows (); \ - octave_idx_type m2_nc = m2.cols (); \ - \ - if (m2_nr == 1 && m2_nc == 1) \ - r = R (m1 OP m2.elem (0,0)); \ - else if (m1_nr != m2_nr || m1_nc != m2_nc) \ - octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ - else \ - { \ - r = R (F (m1, m2.matrix_value ())); \ - } \ - return r; \ +template +auto +msm_add_op (const MArray& m1, const MSparse& m2) -> MArray +{ + typedef decltype (M1 () + M2 ()) RT; + + MArray r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MArray (m1 + m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator +", m1_nr, m1_nc, m2_nr, m2_nc); + else + r = m1 + MArray (m2.array_value ()); + + return r; +} + +template +auto +msm_sub_op (const MArray& m1, const MSparse& m2) -> MArray +{ + MArray r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = m1 - m2.elem (0,0); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator -", m1_nr, m1_nc, m2_nr, m2_nc); + else + r = m1 - MArray (m2.array_value ()); + + return r; +} + +template +auto +msm_mul_op (const MArray& m1, const MSparse& m2) -> MSparse +{ + typedef decltype (M1 () * M2 ()) RT; + + MSparse r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MSparse (m1 * m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator *", m1_nr, m1_nc, m2_nr, m2_nc); + else + { + if (do_mx_check (m1, mx_inline_all_finite)) + { + /* Sparsity pattern is preserved. */ + octave_idx_type m2_nz = m2.nnz (); + r = MSparse (m2_nr, m2_nc, m2_nz); + for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) + { + octave_quit (); + for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) + { + octave_idx_type mri = m2.ridx (i); + RT x = m1(mri, j) * m2.data (i); + if (x != 0.0) + { + r.xdata (k) = x; + r.xridx (k) = m2.ridx (i); + k++; + } + } + r.xcidx (j+1) = k; + } + r.maybe_compress (false); + return r; + } + else + r = MSparse (product (m1, MArray (m2.array_value ()))); + } + + return r; +} + +template +auto +msm_div_op (const MArray& m1, const MSparse& m2) -> MSparse +{ + typedef decltype (M1 () / M2 ()) RT; + + MSparse r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MSparse (m1 / m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator /", m1_nr, m1_nc, m2_nr, m2_nc); + else + { + if (do_mx_check (m1, mx_inline_all_finite)) + { + /* Sparsity pattern is preserved. */ + octave_idx_type m2_nz = m2.nnz (); + r = MSparse (m2_nr, m2_nc, m2_nz); + for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) + { + octave_quit (); + for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) + { + octave_idx_type mri = m2.ridx (i); + RT x = m1(mri, j) / m2.data (i); + if (x != 0.0) + { + r.xdata (k) = x; + r.xridx (k) = m2.ridx (i); + k++; + } + } + r.xcidx (j+1) = k; + } + r.maybe_compress (false); + return r; + } + else + r = MSparse (quotient (m1, MArray (m2.array_value ()))); + } + + return r; +} + +#define SPARSE_MSM_BIN_OP(R, F, OP_FN, M1, M2) \ + R \ + F (const M1& m1, const M2& m2) \ + { \ + return OP_FN (m1, m2); \ } -#define SPARSE_MSM_BIN_OP_2(R, F, OP, M1, M2) \ - R \ - F (const M1& m1, const M2& m2) \ - { \ - R r; \ - \ - octave_idx_type m1_nr = m1.rows (); \ - octave_idx_type m1_nc = m1.cols (); \ - \ - octave_idx_type m2_nr = m2.rows (); \ - octave_idx_type m2_nc = m2.cols (); \ - \ - if (m2_nr == 1 && m2_nc == 1) \ - r = R (m1 OP m2.elem (0,0)); \ - else if (m1_nr != m2_nr || m1_nc != m2_nc) \ - octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ - else \ - { \ - if (do_mx_check (m1, mx_inline_all_finite)) \ - { \ - /* Sparsity pattern is preserved. */ \ - octave_idx_type m2_nz = m2.nnz (); \ - r = R (m2_nr, m2_nc, m2_nz); \ - for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) \ - { \ - octave_quit (); \ - for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) \ - { \ - octave_idx_type mri = m2.ridx (i); \ - R::element_type x = m1(mri, j) OP m2.data (i); \ - if (x != 0.0) \ - { \ - r.xdata (k) = x; \ - r.xridx (k) = m2.ridx (i); \ - k++; \ - } \ - } \ - r.xcidx (j+1) = k; \ - } \ - r.maybe_compress (false); \ - return r; \ - } \ - else \ - r = R (F (m1, m2.matrix_value ())); \ - } \ - \ - return r; \ - } - -#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R1, operator +, +, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R1, operator -, -, M1, M2) \ - SPARSE_MSM_BIN_OP_2 (R2, product, *, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R2, quotient, /, M1, M2) +#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ + SPARSE_MSM_BIN_OP (R1, operator +, msm_add_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R1, operator -, msm_sub_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R2, product, msm_mul_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R2, quotient, msm_div_op, M1, M2) #define SPARSE_MSM_CMP_OP(F, OP, M1, M2) \ SparseBoolMatrix \