Mercurial > jwe > octave
diff liboctave/operators/Sparse-op-defs.h @ 25668:e84d53ffcae5 jwe
tmp template commit
author | John W. Eaton <jwe@octave.org> |
---|---|
date | Tue, 08 Nov 2016 15:06:28 -0500 |
parents | 8b8832ce72b2 |
children |
line wrap: on
line diff
--- a/liboctave/operators/Sparse-op-defs.h Mon Nov 07 11:11:26 2016 -0500 +++ b/liboctave/operators/Sparse-op-defs.h Tue Nov 08 15:06:28 2016 -0500 @@ -714,83 +714,166 @@ // matrix by sparse matrix operations. -#define SPARSE_MSM_BIN_OP_1(R, F, OP, M1, M2) \ - R \ - F (const M1& m1, const M2& m2) \ - { \ - R r; \ - \ - octave_idx_type m1_nr = m1.rows (); \ - octave_idx_type m1_nc = m1.cols (); \ - \ - octave_idx_type m2_nr = m2.rows (); \ - octave_idx_type m2_nc = m2.cols (); \ - \ - if (m2_nr == 1 && m2_nc == 1) \ - r = R (m1 OP m2.elem (0,0)); \ - else if (m1_nr != m2_nr || m1_nc != m2_nc) \ - octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ - else \ - { \ - r = R (F (m1, m2.matrix_value ())); \ - } \ - return r; \ +template <typename M1, typename M2> +auto +msm_add_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () + M2 ())> +{ + typedef decltype (M1 () + M2 ()) RT; + + MArray<RT> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MArray<RT> (m1 + m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator +", m1_nr, m1_nc, m2_nr, m2_nc); + else + r = m1 + MArray<M2> (m2.array_value ()); + + return r; +} + +template <typename M1, typename M2> +auto +msm_sub_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () - M2 ())> +{ + MArray<decltype (M1 () - M2 ())> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = m1 - m2.elem (0,0); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator -", m1_nr, m1_nc, m2_nr, m2_nc); + else + r = m1 - MArray<M2> (m2.array_value ()); + + return r; +} + +template <typename M1, typename M2> +auto +msm_mul_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () * M2 ())> +{ + typedef decltype (M1 () * M2 ()) RT; + + MSparse<RT> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MSparse<RT> (m1 * m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator *", m1_nr, m1_nc, m2_nr, m2_nc); + else + { + if (do_mx_check (m1, mx_inline_all_finite<M1>)) + { + /* Sparsity pattern is preserved. */ + octave_idx_type m2_nz = m2.nnz (); + r = MSparse<RT> (m2_nr, m2_nc, m2_nz); + for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) + { + octave_quit (); + for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) + { + octave_idx_type mri = m2.ridx (i); + RT x = m1(mri, j) * m2.data (i); + if (x != 0.0) + { + r.xdata (k) = x; + r.xridx (k) = m2.ridx (i); + k++; + } + } + r.xcidx (j+1) = k; + } + r.maybe_compress (false); + return r; + } + else + r = MSparse<RT> (product (m1, MArray<M2> (m2.array_value ()))); + } + + return r; +} + +template <typename M1, typename M2> +auto +msm_div_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () / M2 ())> +{ + typedef decltype (M1 () / M2 ()) RT; + + MSparse<RT> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MSparse<RT> (m1 / m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator /", m1_nr, m1_nc, m2_nr, m2_nc); + else + { + if (do_mx_check (m1, mx_inline_all_finite<M1>)) + { + /* Sparsity pattern is preserved. */ + octave_idx_type m2_nz = m2.nnz (); + r = MSparse<RT> (m2_nr, m2_nc, m2_nz); + for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) + { + octave_quit (); + for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) + { + octave_idx_type mri = m2.ridx (i); + RT x = m1(mri, j) / m2.data (i); + if (x != 0.0) + { + r.xdata (k) = x; + r.xridx (k) = m2.ridx (i); + k++; + } + } + r.xcidx (j+1) = k; + } + r.maybe_compress (false); + return r; + } + else + r = MSparse<RT> (quotient (m1, MArray<M2> (m2.array_value ()))); + } + + return r; +} + +#define SPARSE_MSM_BIN_OP(R, F, OP_FN, M1, M2) \ + R \ + F (const M1& m1, const M2& m2) \ + { \ + return OP_FN (m1, m2); \ } -#define SPARSE_MSM_BIN_OP_2(R, F, OP, M1, M2) \ - R \ - F (const M1& m1, const M2& m2) \ - { \ - R r; \ - \ - octave_idx_type m1_nr = m1.rows (); \ - octave_idx_type m1_nc = m1.cols (); \ - \ - octave_idx_type m2_nr = m2.rows (); \ - octave_idx_type m2_nc = m2.cols (); \ - \ - if (m2_nr == 1 && m2_nc == 1) \ - r = R (m1 OP m2.elem (0,0)); \ - else if (m1_nr != m2_nr || m1_nc != m2_nc) \ - octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ - else \ - { \ - if (do_mx_check (m1, mx_inline_all_finite<M1::element_type>)) \ - { \ - /* Sparsity pattern is preserved. */ \ - octave_idx_type m2_nz = m2.nnz (); \ - r = R (m2_nr, m2_nc, m2_nz); \ - for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) \ - { \ - octave_quit (); \ - for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) \ - { \ - octave_idx_type mri = m2.ridx (i); \ - R::element_type x = m1(mri, j) OP m2.data (i); \ - if (x != 0.0) \ - { \ - r.xdata (k) = x; \ - r.xridx (k) = m2.ridx (i); \ - k++; \ - } \ - } \ - r.xcidx (j+1) = k; \ - } \ - r.maybe_compress (false); \ - return r; \ - } \ - else \ - r = R (F (m1, m2.matrix_value ())); \ - } \ - \ - return r; \ - } - -#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R1, operator +, +, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R1, operator -, -, M1, M2) \ - SPARSE_MSM_BIN_OP_2 (R2, product, *, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R2, quotient, /, M1, M2) +#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ + SPARSE_MSM_BIN_OP (R1, operator +, msm_add_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R1, operator -, msm_sub_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R2, product, msm_mul_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R2, quotient, msm_div_op, M1, M2) #define SPARSE_MSM_CMP_OP(F, OP, M1, M2) \ SparseBoolMatrix \