Mercurial > jwe > octave
changeset 25668:e84d53ffcae5 jwe
tmp template commit
author | John W. Eaton <jwe@octave.org> |
---|---|
date | Tue, 08 Nov 2016 15:06:28 -0500 |
parents | 8b8832ce72b2 |
children | |
files | liboctave/array/MArray.h liboctave/operators/Sparse-op-defs.h |
diffstat | 2 files changed, 216 insertions(+), 75 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/array/MArray.h Mon Nov 07 11:11:26 2016 -0500 +++ b/liboctave/array/MArray.h Tue Nov 08 15:06:28 2016 -0500 @@ -54,6 +54,64 @@ template <typename T> MArray<T> quotient (const MArray<T>&, const MArray<T>&); template <typename T> MArray<T> product (const MArray<T>&, const MArray<T>&); +template <typename T1, typename T2, typename OP> +auto +mm_bin_op (const MArray<T1>& x, const MArray<T2>& y, OP op) -> MArray<decltype (op (T1 (), T2 ()))> +{ + typedef decltype (op (T1 (), T2 ())) RT; + + dim_vector dx = x.dims (); + dim_vector dy = y.dims (); + + if (dx == dy) + { + Array<R> r (dx); + op (r.numel (), r.fortran_vec (), x.data (), y.data ()); + return r; + } + else if (is_valid_bsxfun (opname, dx, dy)) + { + return do_bsxfun_op (x, y, op, op1, op2); + } + else + octave::err_nonconformant (opname, dx, dy); + + MArray<RT> r (x.dims ()); + + for (octave_idx_type i = 0; i < n; i++) + r[i] = op (x[i], y[i]); + + return r; +} + +template <typename MT, typename ST, typename OP> +auto +ms_bin_op (const MArray<MT>& x, const ST& y, OP op) -> MArray<decltype (op (MT (), ST ()))> +{ + typedef decltype (op (MT (), ST ())) RT; + + MArray<RT> r (x.dims ()); + + for (octave_idx_type i = 0; i < n; i++) + r[i] = op (x[i], y); + + return r; +} + +template <typename MT, typename ST, typename OP> +auto +sm_bin_op (const ST& x, const MArray<MT>& y, OP op) -> MArray<decltype (op (ST (), MT ()))> +{ + typedef decltype (op (MT (), ST ())) RT; + + MArray<RT> r (x.dims ()); + + for (octave_idx_type i = 0; i < n; i++) + r[i] = op (x, y[i]); + + return r; +} + //! Template for N-dimensional array classes with like-type math operators. template <typename T> class
--- a/liboctave/operators/Sparse-op-defs.h Mon Nov 07 11:11:26 2016 -0500 +++ b/liboctave/operators/Sparse-op-defs.h Tue Nov 08 15:06:28 2016 -0500 @@ -714,83 +714,166 @@ // matrix by sparse matrix operations. -#define SPARSE_MSM_BIN_OP_1(R, F, OP, M1, M2) \ - R \ - F (const M1& m1, const M2& m2) \ - { \ - R r; \ - \ - octave_idx_type m1_nr = m1.rows (); \ - octave_idx_type m1_nc = m1.cols (); \ - \ - octave_idx_type m2_nr = m2.rows (); \ - octave_idx_type m2_nc = m2.cols (); \ - \ - if (m2_nr == 1 && m2_nc == 1) \ - r = R (m1 OP m2.elem (0,0)); \ - else if (m1_nr != m2_nr || m1_nc != m2_nc) \ - octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ - else \ - { \ - r = R (F (m1, m2.matrix_value ())); \ - } \ - return r; \ +template <typename M1, typename M2> +auto +msm_add_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () + M2 ())> +{ + typedef decltype (M1 () + M2 ()) RT; + + MArray<RT> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MArray<RT> (m1 + m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator +", m1_nr, m1_nc, m2_nr, m2_nc); + else + r = m1 + MArray<M2> (m2.array_value ()); + + return r; +} + +template <typename M1, typename M2> +auto +msm_sub_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () - M2 ())> +{ + MArray<decltype (M1 () - M2 ())> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = m1 - m2.elem (0,0); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator -", m1_nr, m1_nc, m2_nr, m2_nc); + else + r = m1 - MArray<M2> (m2.array_value ()); + + return r; +} + +template <typename M1, typename M2> +auto +msm_mul_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () * M2 ())> +{ + typedef decltype (M1 () * M2 ()) RT; + + MSparse<RT> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MSparse<RT> (m1 * m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator *", m1_nr, m1_nc, m2_nr, m2_nc); + else + { + if (do_mx_check (m1, mx_inline_all_finite<M1>)) + { + /* Sparsity pattern is preserved. */ + octave_idx_type m2_nz = m2.nnz (); + r = MSparse<RT> (m2_nr, m2_nc, m2_nz); + for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) + { + octave_quit (); + for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) + { + octave_idx_type mri = m2.ridx (i); + RT x = m1(mri, j) * m2.data (i); + if (x != 0.0) + { + r.xdata (k) = x; + r.xridx (k) = m2.ridx (i); + k++; + } + } + r.xcidx (j+1) = k; + } + r.maybe_compress (false); + return r; + } + else + r = MSparse<RT> (product (m1, MArray<M2> (m2.array_value ()))); + } + + return r; +} + +template <typename M1, typename M2> +auto +msm_div_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () / M2 ())> +{ + typedef decltype (M1 () / M2 ()) RT; + + MSparse<RT> r; + + octave_idx_type m1_nr = m1.rows (); + octave_idx_type m1_nc = m1.cols (); + + octave_idx_type m2_nr = m2.rows (); + octave_idx_type m2_nc = m2.cols (); + + if (m2_nr == 1 && m2_nc == 1) + r = MSparse<RT> (m1 / m2.elem (0,0)); + else if (m1_nr != m2_nr || m1_nc != m2_nc) + octave::err_nonconformant ("operator /", m1_nr, m1_nc, m2_nr, m2_nc); + else + { + if (do_mx_check (m1, mx_inline_all_finite<M1>)) + { + /* Sparsity pattern is preserved. */ + octave_idx_type m2_nz = m2.nnz (); + r = MSparse<RT> (m2_nr, m2_nc, m2_nz); + for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) + { + octave_quit (); + for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) + { + octave_idx_type mri = m2.ridx (i); + RT x = m1(mri, j) / m2.data (i); + if (x != 0.0) + { + r.xdata (k) = x; + r.xridx (k) = m2.ridx (i); + k++; + } + } + r.xcidx (j+1) = k; + } + r.maybe_compress (false); + return r; + } + else + r = MSparse<RT> (quotient (m1, MArray<M2> (m2.array_value ()))); + } + + return r; +} + +#define SPARSE_MSM_BIN_OP(R, F, OP_FN, M1, M2) \ + R \ + F (const M1& m1, const M2& m2) \ + { \ + return OP_FN (m1, m2); \ } -#define SPARSE_MSM_BIN_OP_2(R, F, OP, M1, M2) \ - R \ - F (const M1& m1, const M2& m2) \ - { \ - R r; \ - \ - octave_idx_type m1_nr = m1.rows (); \ - octave_idx_type m1_nc = m1.cols (); \ - \ - octave_idx_type m2_nr = m2.rows (); \ - octave_idx_type m2_nc = m2.cols (); \ - \ - if (m2_nr == 1 && m2_nc == 1) \ - r = R (m1 OP m2.elem (0,0)); \ - else if (m1_nr != m2_nr || m1_nc != m2_nc) \ - octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ - else \ - { \ - if (do_mx_check (m1, mx_inline_all_finite<M1::element_type>)) \ - { \ - /* Sparsity pattern is preserved. */ \ - octave_idx_type m2_nz = m2.nnz (); \ - r = R (m2_nr, m2_nc, m2_nz); \ - for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) \ - { \ - octave_quit (); \ - for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) \ - { \ - octave_idx_type mri = m2.ridx (i); \ - R::element_type x = m1(mri, j) OP m2.data (i); \ - if (x != 0.0) \ - { \ - r.xdata (k) = x; \ - r.xridx (k) = m2.ridx (i); \ - k++; \ - } \ - } \ - r.xcidx (j+1) = k; \ - } \ - r.maybe_compress (false); \ - return r; \ - } \ - else \ - r = R (F (m1, m2.matrix_value ())); \ - } \ - \ - return r; \ - } - -#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R1, operator +, +, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R1, operator -, -, M1, M2) \ - SPARSE_MSM_BIN_OP_2 (R2, product, *, M1, M2) \ - SPARSE_MSM_BIN_OP_1 (R2, quotient, /, M1, M2) +#define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ + SPARSE_MSM_BIN_OP (R1, operator +, msm_add_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R1, operator -, msm_sub_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R2, product, msm_mul_op, M1, M2) \ + SPARSE_MSM_BIN_OP (R2, quotient, msm_div_op, M1, M2) #define SPARSE_MSM_CMP_OP(F, OP, M1, M2) \ SparseBoolMatrix \