Mercurial > jwe > octave
view liboctave/operators/Sparse-op-defs.h @ 25668:e84d53ffcae5 jwe
tmp template commit
author | John W. Eaton <jwe@octave.org> |
---|---|
date | Tue, 08 Nov 2016 15:06:28 -0500 |
parents | 8b8832ce72b2 |
children |
line wrap: on
line source
/* Copyright (C) 2004-2018 David Bateman Copyright (C) 1998-2004 Andy Adler Copyright (C) 2008 Jaroslav Hajek This file is part of Octave. Octave is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Octave is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Octave; see the file COPYING. If not, see <https://www.gnu.org/licenses/>. */ #if ! defined (octave_Sparse_op_defs_h) #define octave_Sparse_op_defs_h 1 #include "octave-config.h" #include "Array-util.h" #include "lo-array-errwarn.h" #include "mx-inlines.cc" #include "oct-locbuf.h" // sparse matrix by scalar operations. #define SPARSE_SMS_BIN_OP_1(R, F, OP, M, S) \ R \ F (const M& m, const S& s) \ { \ return add_or_sub (m, s, OP<M::element_type, S>); \ } #define SPARSE_SMS_BIN_OP_2(R, F, OP, M, S) \ R \ F (const M& m, const S& s) \ { \ return mul_or_div (m, s, OP<M::element_type, S>); \ } #define SPARSE_SMS_BIN_OPS(R1, R2, M, S) \ SPARSE_SMS_BIN_OP_1 (R1, operator +, octave::math::add, M, S) \ SPARSE_SMS_BIN_OP_1 (R1, operator -, octave::math::sub, M, S) \ SPARSE_SMS_BIN_OP_2 (R2, operator *, octave::math::mul, M, S) \ SPARSE_SMS_BIN_OP_2 (R2, operator /, octave::math::div, M, S) template <typename MT, typename ST, typename OP> Sparse<bool> sms_cmp_op (const Sparse<MT>& m, const ST& s, OP op) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); Sparse<bool> r; MT m_zero = MT (); if (op (m_zero, s)) { r = SparseBoolMatrix (nr, nc, true); for (octave_idx_type j = 0; j < nc; j++) for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (! 
op (m.data (i), s)) r.data (m.ridx (i) + j * nr) = false; r.maybe_compress (true); } else { r = SparseBoolMatrix (nr, nc, m.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < nc; j++) { for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (op (m.data (i), s)) { r.ridx (nel) = m.ridx (i); r.data (nel++) = true; } r.cidx (j + 1) = nel; } r.maybe_compress (false); } return r; } #define SPARSE_SMS_CMP_OP(F, OP, M, S) \ SparseBoolMatrix \ F (const M& m, const S& s) \ { \ return sms_cmp_op (m, s, OP<M::element_type, S>); \ } #define SPARSE_SMS_CMP_OPS(M, S) \ SPARSE_SMS_CMP_OP (mx_el_lt, octave::math::lt, M, S) \ SPARSE_SMS_CMP_OP (mx_el_le, octave::math::le, M, S) \ SPARSE_SMS_CMP_OP (mx_el_ge, octave::math::ge, M, S) \ SPARSE_SMS_CMP_OP (mx_el_gt, octave::math::gt, M, S) \ SPARSE_SMS_CMP_OP (mx_el_eq, octave::math::eq, M, S) \ SPARSE_SMS_CMP_OP (mx_el_ne, octave::math::ne, M, S) #define SPARSE_SMS_EQNE_OPS(M, S) \ SPARSE_SMS_CMP_OP (mx_el_eq, octave::math::eq, M, S) \ SPARSE_SMS_CMP_OP (mx_el_ne, octave::math::ne, M, S) template <typename MT, typename ST> Sparse<bool> mx_el_or (const Sparse<MT>& m, const ST& s) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); Sparse<bool> r; MT lhs_zero = MT (); ST rhs_zero = ST (); if (nr > 0 && nc > 0) { if (s != rhs_zero) r = Sparse<bool> (nr, nc, true); else { r = SparseBoolMatrix (nr, nc, m.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < nc; j++) { for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (m.data (i) != lhs_zero) { r.ridx (nel) = m.ridx (i); r.data (nel++) = true; } r.cidx (j + 1) = nel; } r.maybe_compress (false); } } return r; } #define SPARSE_SMS_BOOL_OR_OP(M, S) \ SparseBoolMatrix \ mx_el_or (const M& m, const S& s) \ { \ return mx_el_or<M::element_type, S> (m, s); \ } template <typename MT, typename ST> Sparse<bool> mx_el_and (const Sparse<MT>& m, const 
ST& s) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); Sparse<bool> r; MT lhs_zero = MT (); ST rhs_zero = ST (); if (nr > 0 && nc > 0) { if (s != rhs_zero) { r = Sparse<bool> (nr, nc, m.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < nc; j++) { for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (m.data (i) != lhs_zero) { r.ridx (nel) = m.ridx (i); r.data (nel++) = true; } r.cidx (j + 1) = nel; } r.maybe_compress (false); } else r = Sparse<bool> (nr, nc); } return r; } #define SPARSE_SMS_BOOL_AND_OP(M, S) \ SparseBoolMatrix \ mx_el_and (const M& m, const S& s) \ { \ return mx_el_and<M::element_type, S> (m, s); \ } #define SPARSE_SMS_BOOL_OPS(M, S) \ SPARSE_SMS_BOOL_AND_OP (M, S) \ SPARSE_SMS_BOOL_OR_OP (M, S) // scalar by sparse matrix operations. #define SPARSE_SSM_BIN_OP_1(R, F, OP, S, M) \ R \ F (const S& s, const M& m) \ { \ return add_or_sub (s, m, OP<S, M::element_type>); \ } #define SPARSE_SSM_BIN_OP_2(R, F, OP, S, M) \ R \ F (const S& s, const M& m) \ { \ return mul_or_div (s, m, OP<S, M::element_type>); \ } #define SPARSE_SSM_BIN_OPS(R1, R2, S, M) \ SPARSE_SSM_BIN_OP_1 (R1, operator +, octave::math::add, S, M) \ SPARSE_SSM_BIN_OP_1 (R1, operator -, octave::math::sub, S, M) \ SPARSE_SSM_BIN_OP_2 (R2, operator *, octave::math::mul, S, M) \ SPARSE_SSM_BIN_OP_2 (R2, operator /, octave::math::div, S, M) template <typename ST, typename MT, typename OP> Sparse<bool> ssm_cmp_op (const ST& s, const Sparse<MT>& m, OP op) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); Sparse<bool> r; MT m_zero = MT (); if (op (s, m_zero)) { r = Sparse<bool> (nr, nc, true); for (octave_idx_type j = 0; j < nc; j++) for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (! 
op (s, m.data (i))) r.data (m.ridx (i) + j * nr) = false; r.maybe_compress (true); } else { r = Sparse<bool> (nr, nc, m.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < nc; j++) { for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (op (s, m.data (i))) { r.ridx (nel) = m.ridx (i); r.data (nel++) = true; } r.cidx (j + 1) = nel; } r.maybe_compress (false); } return r; } #define SPARSE_SSM_CMP_OP(F, OP, S, M) \ SparseBoolMatrix \ F (const S& s, const M& m) \ { \ return ssm_cmp_op (s, m, OP<S, M::element_type>); \ } \ #define SPARSE_SSM_CMP_OPS(S, M) \ SPARSE_SSM_CMP_OP (mx_el_lt, octave::math::lt, S, M) \ SPARSE_SSM_CMP_OP (mx_el_le, octave::math::le, S, M) \ SPARSE_SSM_CMP_OP (mx_el_ge, octave::math::ge, S, M) \ SPARSE_SSM_CMP_OP (mx_el_gt, octave::math::gt, S, M) \ SPARSE_SSM_CMP_OP (mx_el_eq, octave::math::eq, S, M) \ SPARSE_SSM_CMP_OP (mx_el_ne, octave::math::ne, S, M) #define SPARSE_SSM_EQNE_OPS(S, M) \ SPARSE_SSM_CMP_OP (mx_el_eq, octave::math::eq, S, M) \ SPARSE_SSM_CMP_OP (mx_el_ne, octave::math::ne, S, M) template <typename ST, typename MT> Sparse<bool> mx_el_or (const ST& s, const Sparse<MT>& m) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); Sparse<bool> r; ST lhs_zero = ST (); MT rhs_zero = MT (); if (nr > 0 && nc > 0) { if (s != lhs_zero) r = Sparse<bool> (nr, nc, true); else { r = Sparse<bool> (nr, nc, m.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < nc; j++) { for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (m.data (i) != rhs_zero) { r.ridx (nel) = m.ridx (i); r.data (nel++) = true; } r.cidx (j + 1) = nel; } r.maybe_compress (false); } } return r; } #define SPARSE_SSM_BOOL_OR_OP(S, M) \ SparseBoolMatrix \ mx_el_or (const S& s, const M& m) \ { \ return mx_el_or<S, M::element_type> (s, m); \ } template <typename ST, typename MT> Sparse<bool> mx_el_and (const ST& s, const Sparse<MT>& 
m) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); Sparse<bool> r; ST lhs_zero = ST (); MT rhs_zero = MT (); if (nr > 0 && nc > 0) { if (s != lhs_zero) { r = Sparse<bool> (nr, nc, m.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < nc; j++) { for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) if (m.data (i) != rhs_zero) { r.ridx (nel) = m.ridx (i); r.data (nel++) = true; } r.cidx (j + 1) = nel; } r.maybe_compress (false); } else r = Sparse<bool> (nr, nc); } return r; } #define SPARSE_SSM_BOOL_AND_OP(S, M) \ SparseBoolMatrix \ mx_el_and (const S& s, const M& m) \ { \ return mx_el_and<S, M::element_type> (s, m); \ } #define SPARSE_SSM_BOOL_OPS(S, M) \ SPARSE_SSM_BOOL_AND_OP (S, M) \ SPARSE_SSM_BOOL_OR_OP (S, M) // sparse matrix by sparse matrix operations. #define SPARSE_SMSM_BIN_OP_1(R, F, OP, M1, M2, NEGATE) \ R \ F (const M1& m1, const M2& m2) \ { \ return add_or_sub (m1, m2, OP<M1::element_type, M2::element_type>, #F, NEGATE); \ } #define SPARSE_SMSM_BIN_OP_2(R, F, M1, M2) \ R \ F (const M1& m1, const M2& m2) \ { \ return F<M1::element_type, M2::element_type> (m1, m2); \ } #define SPARSE_SMSM_BIN_OPS(R1, R2, M1, M2) \ SPARSE_SMSM_BIN_OP_1 (R1, operator +, octave::math::add, M1, M2, false) \ SPARSE_SMSM_BIN_OP_1 (R1, operator -, octave::math::sub, M1, M2, true) \ SPARSE_SMSM_BIN_OP_2 (R1, product, M1, M2) \ SPARSE_SMSM_BIN_OP_2 (R1, quotient, M1, M2) template <typename M1, typename M2, typename OP> Sparse<bool> smsm_cmp_op (const Sparse<M1>& m1, const Sparse<M2>& m2, OP op, const char *op_name) { Sparse<bool> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); M1 Z1 = M1 (); M2 Z2 = M2 (); if (m1_nr == 1 && m1_nc == 1) return ssm_cmp_op (m1.elem (0, 0), m2, op); else if (m2_nr == 1 && m2_nc == 1) return sms_cmp_op (m1, m2.elem (0, 0), op); else if (m1_nr == m2_nr && m1_nc == m2_nc) 
{ if (m1_nr != 0 || m1_nc != 0) { if (op (Z1, Z2)) { r = Sparse<bool> (m1_nr, m1_nc, true); for (octave_idx_type j = 0; j < m1_nc; j++) { octave_idx_type i1 = m1.cidx (j); octave_idx_type e1 = m1.cidx (j+1); octave_idx_type i2 = m2.cidx (j); octave_idx_type e2 = m2.cidx (j+1); while (i1 < e1 || i2 < e2) { if (i1 == e1 || (i2 < e2 && m1.ridx (i1) > m2.ridx (i2))) { if (! op (Z1, m2.data (i2))) r.data (m2.ridx (i2) + j * m1_nr) = false; i2++; } else if (i2 == e2 || m1.ridx (i1) < m2.ridx (i2)) { if (! op (m1.data (i1), Z2)) r.data (m1.ridx (i1) + j * m1_nr) = false; i1++; } else { if (! op (m1.data (i1), m2.data (i2))) r.data (m1.ridx (i1) + j * m1_nr) = false; i1++; i2++; } } } r.maybe_compress (true); } else { r = Sparse<bool> (m1_nr, m1_nc, m1.nnz () + m2.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < m1_nc; j++) { octave_idx_type i1 = m1.cidx (j); octave_idx_type e1 = m1.cidx (j+1); octave_idx_type i2 = m2.cidx (j); octave_idx_type e2 = m2.cidx (j+1); while (i1 < e1 || i2 < e2) { if (i1 == e1 || (i2 < e2 && m1.ridx (i1) > m2.ridx (i2))) { if (op (Z1, m2.data (i2))) { r.ridx (nel) = m2.ridx (i2); r.data (nel++) = true; } i2++; } else if (i2 == e2 || m1.ridx (i1) < m2.ridx (i2)) { if (op (m1.data (i1), Z2)) { r.ridx (nel) = m1.ridx (i1); r.data (nel++) = true; } i1++; } else { if (op (m1.data (i1), m2.data (i2))) { r.ridx (nel) = m1.ridx (i1); r.data (nel++) = true; } i1++; i2++; } } r.cidx (j + 1) = nel; } r.maybe_compress (false); } } } else { if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0)) octave::err_nonconformant (op_name, m1_nr, m1_nc, m2_nr, m2_nc); } return r; } #define SPARSE_SMSM_CMP_OP(F, OP, M1, M2) \ SparseBoolMatrix \ F (const M1& m1, const M2& m2) \ { \ return smsm_cmp_op (m1, m2, OP<M1::element_type, M2::element_type>, #F); \ } #define SPARSE_SMSM_CMP_OPS(M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_lt, octave::math::lt, M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_le, octave::math::le, 
M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_ge, octave::math::ge, M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_gt, octave::math::gt, M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_eq, octave::math::eq, M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_ne, octave::math::ne, M1, M2) #define SPARSE_SMSM_EQNE_OPS(M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_eq, octave::math::eq, M1, M2) \ SPARSE_SMSM_CMP_OP (mx_el_ne, octave::math::ne, M1, M2) template <typename M1, typename M2> Sparse<bool> mx_el_and (const Sparse<M1>& m1, const Sparse<M2>& m2) { Sparse<bool> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); M1 lhs_zero = M1 (); M2 rhs_zero = M2 (); if (m1_nr == 1 && m1_nc == 1) return mx_el_and (m1.elem (0,0), m2); else if (m2_nr == 1 && m2_nc == 1) return mx_el_and (m1, m2.elem (0,0)); else if (m1_nr == m2_nr && m1_nc == m2_nc) { if (m1_nr != 0 || m1_nc != 0) { r = Sparse<bool> (m1_nr, m1_nc, m1.nnz () + m2.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < m1_nc; j++) { octave_idx_type i1 = m1.cidx (j); octave_idx_type e1 = m1.cidx (j+1); octave_idx_type i2 = m2.cidx (j); octave_idx_type e2 = m2.cidx (j+1); while (i1 < e1 || i2 < e2) { if (i1 == e1 || (i2 < e2 && m1.ridx (i1) > m2.ridx (i2))) i2++; else if (i2 == e2 || m1.ridx (i1) < m2.ridx (i2)) i1++; else { if (m1.data (i1) != lhs_zero && m2.data (i2) != rhs_zero) { r.ridx (nel) = m1.ridx (i1); r.data (nel++) = true; } i1++; i2++; } } r.cidx (j + 1) = nel; } r.maybe_compress (false); } } else { if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0)) octave::err_nonconformant ("mx_el_and_", m1_nr, m1_nc, m2_nr, m2_nc); } return r; } #define SPARSE_SMSM_BOOL_AND_OP(M1, M2) \ SparseBoolMatrix \ mx_el_and (const M1& m1, const M2& m2) \ { \ return mx_el_and<M1::element_type, M2::element_type> (m1, m2); \ } template <typename M1, typename M2> Sparse<bool> mx_el_or (const Sparse<M1>& m1, const Sparse<M2>& 
m2) { Sparse<bool> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); M1 lhs_zero = M1 (); M2 rhs_zero = M2 (); if (m1_nr == 1 && m1_nc == 1) return mx_el_or (m1.elem (0,0), m2); else if (m2_nr == 1 && m2_nc == 1) return mx_el_or (m1, m2.elem (0,0)); else if (m1_nr == m2_nr && m1_nc == m2_nc) { if (m1_nr != 0 || m1_nc != 0) { r = Sparse<bool> (m1_nr, m1_nc, m1.nnz () + m2.nnz ()); r.cidx (0) = static_cast<octave_idx_type> (0); octave_idx_type nel = 0; for (octave_idx_type j = 0; j < m1_nc; j++) { octave_idx_type i1 = m1.cidx (j); octave_idx_type e1 = m1.cidx (j+1); octave_idx_type i2 = m2.cidx (j); octave_idx_type e2 = m2.cidx (j+1); while (i1 < e1 || i2 < e2) { if (i1 == e1 || (i2 < e2 && m1.ridx (i1) > m2.ridx (i2))) { if (m2.data (i2) != rhs_zero) { r.ridx (nel) = m2.ridx (i2); r.data (nel++) = true; } i2++; } else if (i2 == e2 || m1.ridx (i1) < m2.ridx (i2)) { if (m1.data (i1) != lhs_zero) { r.ridx (nel) = m1.ridx (i1); r.data (nel++) = true; } i1++; } else { if (m1.data (i1) != lhs_zero || m2.data (i2) != rhs_zero) { r.ridx (nel) = m1.ridx (i1); r.data (nel++) = true; } i1++; i2++; } } r.cidx (j + 1) = nel; } r.maybe_compress (false); } } else { if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0)) octave::err_nonconformant ("mx_el_or", m1_nr, m1_nc, m2_nr, m2_nc); } return r; } #define SPARSE_SMSM_BOOL_OR_OP(M1, M2) \ SparseBoolMatrix \ mx_el_or (const M1& m1, const M2& m2) \ { \ return mx_el_or<M1::element_type, M2::element_type> (m1, m2); \ } #define SPARSE_SMSM_BOOL_OPS(M1, M2) \ SPARSE_SMSM_BOOL_AND_OP (M1, M2) \ SPARSE_SMSM_BOOL_OR_OP (M1, M2) // matrix by sparse matrix operations. 
template <typename M1, typename M2> auto msm_add_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () + M2 ())> { typedef decltype (M1 () + M2 ()) RT; MArray<RT> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); if (m2_nr == 1 && m2_nc == 1) r = MArray<RT> (m1 + m2.elem (0,0)); else if (m1_nr != m2_nr || m1_nc != m2_nc) octave::err_nonconformant ("operator +", m1_nr, m1_nc, m2_nr, m2_nc); else r = m1 + MArray<M2> (m2.array_value ()); return r; } template <typename M1, typename M2> auto msm_sub_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MArray<decltype (M1 () - M2 ())> { MArray<decltype (M1 () - M2 ())> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); if (m2_nr == 1 && m2_nc == 1) r = m1 - m2.elem (0,0); else if (m1_nr != m2_nr || m1_nc != m2_nc) octave::err_nonconformant ("operator -", m1_nr, m1_nc, m2_nr, m2_nc); else r = m1 - MArray<M2> (m2.array_value ()); return r; } template <typename M1, typename M2> auto msm_mul_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () * M2 ())> { typedef decltype (M1 () * M2 ()) RT; MSparse<RT> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); if (m2_nr == 1 && m2_nc == 1) r = MSparse<RT> (m1 * m2.elem (0,0)); else if (m1_nr != m2_nr || m1_nc != m2_nc) octave::err_nonconformant ("operator *", m1_nr, m1_nc, m2_nr, m2_nc); else { if (do_mx_check (m1, mx_inline_all_finite<M1>)) { /* Sparsity pattern is preserved. 
*/ octave_idx_type m2_nz = m2.nnz (); r = MSparse<RT> (m2_nr, m2_nc, m2_nz); for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) { octave_quit (); for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) { octave_idx_type mri = m2.ridx (i); RT x = m1(mri, j) * m2.data (i); if (x != 0.0) { r.xdata (k) = x; r.xridx (k) = m2.ridx (i); k++; } } r.xcidx (j+1) = k; } r.maybe_compress (false); return r; } else r = MSparse<RT> (product (m1, MArray<M2> (m2.array_value ()))); } return r; } template <typename M1, typename M2> auto msm_div_op (const MArray<M1>& m1, const MSparse<M2>& m2) -> MSparse<decltype (M1 () / M2 ())> { typedef decltype (M1 () / M2 ()) RT; MSparse<RT> r; octave_idx_type m1_nr = m1.rows (); octave_idx_type m1_nc = m1.cols (); octave_idx_type m2_nr = m2.rows (); octave_idx_type m2_nc = m2.cols (); if (m2_nr == 1 && m2_nc == 1) r = MSparse<RT> (m1 / m2.elem (0,0)); else if (m1_nr != m2_nr || m1_nc != m2_nc) octave::err_nonconformant ("operator /", m1_nr, m1_nc, m2_nr, m2_nc); else { if (do_mx_check (m1, mx_inline_all_finite<M1>)) { /* Sparsity pattern is preserved. 
*/ octave_idx_type m2_nz = m2.nnz (); r = MSparse<RT> (m2_nr, m2_nc, m2_nz); for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) { octave_quit (); for (octave_idx_type i = m2.cidx (j); i < m2.cidx (j+1); i++) { octave_idx_type mri = m2.ridx (i); RT x = m1(mri, j) / m2.data (i); if (x != 0.0) { r.xdata (k) = x; r.xridx (k) = m2.ridx (i); k++; } } r.xcidx (j+1) = k; } r.maybe_compress (false); return r; } else r = MSparse<RT> (quotient (m1, MArray<M2> (m2.array_value ()))); } return r; } #define SPARSE_MSM_BIN_OP(R, F, OP_FN, M1, M2) \ R \ F (const M1& m1, const M2& m2) \ { \ return OP_FN (m1, m2); \ } #define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \ SPARSE_MSM_BIN_OP (R1, operator +, msm_add_op, M1, M2) \ SPARSE_MSM_BIN_OP (R1, operator -, msm_sub_op, M1, M2) \ SPARSE_MSM_BIN_OP (R2, product, msm_mul_op, M1, M2) \ SPARSE_MSM_BIN_OP (R2, quotient, msm_div_op, M1, M2) #define SPARSE_MSM_CMP_OP(F, OP, M1, M2) \ SparseBoolMatrix \ F (const M1& m1, const M2& m2) \ { \ SparseBoolMatrix r; \ \ octave_idx_type m1_nr = m1.rows (); \ octave_idx_type m1_nc = m1.cols (); \ \ octave_idx_type m2_nr = m2.rows (); \ octave_idx_type m2_nc = m2.cols (); \ \ if (m2_nr == 1 && m2_nc == 1) \ r = SparseBoolMatrix (F (m1, m2.elem (0,0))); \ else if (m1_nr == m2_nr && m1_nc == m2_nc) \ { \ if (m1_nr != 0 || m1_nc != 0) \ { \ /* Count num of nonzero elements */ \ octave_idx_type nel = 0; \ for (octave_idx_type j = 0; j < m1_nc; j++) \ for (octave_idx_type i = 0; i < m1_nr; i++) \ if (m1.elem (i, j) OP m2.elem (i, j)) \ nel++; \ \ r = SparseBoolMatrix (m1_nr, m1_nc, nel); \ \ octave_idx_type ii = 0; \ r.cidx (0) = 0; \ for (octave_idx_type j = 0; j < m1_nc; j++) \ { \ for (octave_idx_type i = 0; i < m1_nr; i++) \ { \ bool el = m1.elem (i, j) OP m2.elem (i, j); \ if (el) \ { \ r.data (ii) = el; \ r.ridx (ii++) = i; \ } \ } \ r.cidx (j+1) = ii; \ } \ } \ } \ else \ { \ if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0)) \ octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ } \ 
return r; \ } #define SPARSE_MSM_CMP_OPS(M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_lt, <, M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_le, <=, M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_ge, >=, M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_gt, >, M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_eq, ==, M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_ne, !=, M1, M2) #define SPARSE_MSM_EQNE_OPS(M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_eq, ==, M1, M2) \ SPARSE_MSM_CMP_OP (mx_el_ne, !=, M1, M2) #define SPARSE_MSM_BOOL_OP(F, OP, M1, M2) \ SparseBoolMatrix \ F (const M1& m1, const M2& m2) \ { \ SparseBoolMatrix r; \ \ octave_idx_type m1_nr = m1.rows (); \ octave_idx_type m1_nc = m1.cols (); \ \ octave_idx_type m2_nr = m2.rows (); \ octave_idx_type m2_nc = m2.cols (); \ \ M1::element_type lhs_zero = M1::element_type (); \ M2::element_type rhs_zero = M2::element_type (); \ \ if (m2_nr == 1 && m2_nc == 1) \ r = SparseBoolMatrix (F (m1, m2.elem (0,0))); \ else if (m1_nr == m2_nr && m1_nc == m2_nc) \ { \ if (m1_nr != 0 || m1_nc != 0) \ { \ /* Count num of nonzero elements */ \ octave_idx_type nel = 0; \ for (octave_idx_type j = 0; j < m1_nc; j++) \ for (octave_idx_type i = 0; i < m1_nr; i++) \ if ((m1.elem (i, j) != lhs_zero) \ OP (m2.elem (i, j) != rhs_zero)) \ nel++; \ \ r = SparseBoolMatrix (m1_nr, m1_nc, nel); \ \ octave_idx_type ii = 0; \ r.cidx (0) = 0; \ for (octave_idx_type j = 0; j < m1_nc; j++) \ { \ for (octave_idx_type i = 0; i < m1_nr; i++) \ { \ bool el = (m1.elem (i, j) != lhs_zero) \ OP (m2.elem (i, j) != rhs_zero); \ if (el) \ { \ r.data (ii) = el; \ r.ridx (ii++) = i; \ } \ } \ r.cidx (j+1) = ii; \ } \ } \ } \ else \ { \ if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0)) \ octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \ } \ return r; \ } #define SPARSE_MSM_BOOL_OPS(M1, M2) \ SPARSE_MSM_BOOL_OP (mx_el_and, &&, M1, M2) \ SPARSE_MSM_BOOL_OP (mx_el_or, ||, M1, M2) // sparse matrix by matrix operations. 
// Sparse M1 OP dense M2 (addition/subtraction): a 1x1 sparse operand
// is treated as a scalar; otherwise the sparse operand is densified.
#define SPARSE_SMM_BIN_OP_1(R, F, OP, M1, M2)                           \
  R                                                                     \
  F (const M1& m1, const M2& m2)                                        \
  {                                                                     \
    R r;                                                                \
                                                                        \
    octave_idx_type m1_nr = m1.rows ();                                 \
    octave_idx_type m1_nc = m1.cols ();                                 \
                                                                        \
    octave_idx_type m2_nr = m2.rows ();                                 \
    octave_idx_type m2_nc = m2.cols ();                                 \
                                                                        \
    if (m1_nr == 1 && m1_nc == 1)                                       \
      r = R (m1.elem (0,0) OP m2);                                      \
    else if (m1_nr != m2_nr || m1_nc != m2_nc)                          \
      octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc);       \
    else                                                                \
      {                                                                 \
        r = R (m1.matrix_value () OP m2);                               \
      }                                                                 \
    return r;                                                           \
  }

// sm .* m preserves sparsity if m contains no Infs nor Nans.
#define SPARSE_SMM_BIN_OP_2_CHECK_product(ET)   \
  do_mx_check (m2, mx_inline_all_finite<ET>)

// sm ./ m preserves sparsity if m contains no NaNs or zeros.
#define SPARSE_SMM_BIN_OP_2_CHECK_quotient(ET)  \
  ! do_mx_check (m2, mx_inline_any_nan<ET>) && m2.nnz () == m2.numel ()

// Sparse M1 OP dense M2 (product/quotient): preserve m1's sparsity
// pattern when the per-operation check above allows it, otherwise
// densify m1 and fall back to the dense implementation of F.
#define SPARSE_SMM_BIN_OP_2(R, F, OP, M1, M2)                           \
  R                                                                     \
  F (const M1& m1, const M2& m2)                                        \
  {                                                                     \
    R r;                                                                \
                                                                        \
    octave_idx_type m1_nr = m1.rows ();                                 \
    octave_idx_type m1_nc = m1.cols ();                                 \
                                                                        \
    octave_idx_type m2_nr = m2.rows ();                                 \
    octave_idx_type m2_nc = m2.cols ();                                 \
                                                                        \
    if (m1_nr == 1 && m1_nc == 1)                                       \
      r = R (m1.elem (0,0) OP m2);                                      \
    else if (m1_nr != m2_nr || m1_nc != m2_nc)                          \
      octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc);       \
    else                                                                \
      {                                                                 \
        if (SPARSE_SMM_BIN_OP_2_CHECK_ ## F(M2::element_type))          \
          {                                                             \
            /* Sparsity pattern is preserved. */                        \
            octave_idx_type m1_nz = m1.nnz ();                          \
            r = R (m1_nr, m1_nc, m1_nz);                                \
            for (octave_idx_type j = 0, k = 0; j < m1_nc; j++)          \
              {                                                         \
                octave_quit ();                                         \
                for (octave_idx_type i = m1.cidx (j);                   \
                     i < m1.cidx (j+1); i++)                            \
                  {                                                     \
                    octave_idx_type mri = m1.ridx (i);                  \
                    R::element_type x = m1.data (i) OP m2 (mri, j);     \
                    if (x != 0.0)                                       \
                      {                                                 \
                        r.xdata (k) = x;                                \
                        r.xridx (k) = m1.ridx (i);                      \
                        k++;                                            \
                      }                                                 \
                  }                                                     \
                r.xcidx (j+1) = k;                                      \
              }                                                         \
            r.maybe_compress (false);                                   \
            return r;                                                   \
          }                                                             \
        else                                                            \
          r = R (F (m1.matrix_value (), m2));                           \
      }                                                                 \
                                                                        \
    return r;                                                           \
  }

#define SPARSE_SMM_BIN_OPS(R1, R2, M1, M2)              \
  SPARSE_SMM_BIN_OP_1 (R1, operator +, +, M1, M2)       \
  SPARSE_SMM_BIN_OP_1 (R1, operator -, -, M1, M2)       \
  SPARSE_SMM_BIN_OP_2 (R2, product,  *, M1, M2)         \
  SPARSE_SMM_BIN_OP_2 (R2, quotient, /, M1, M2)

// Element-wise comparison of sparse M1 against dense M2; a 1x1 sparse
// operand is treated as a scalar.  Two passes: count, then fill.
#define SPARSE_SMM_CMP_OP(F, OP, M1, M2)                                \
  SparseBoolMatrix                                                      \
  F (const M1& m1, const M2& m2)                                        \
  {                                                                     \
    SparseBoolMatrix r;                                                 \
                                                                        \
    octave_idx_type m1_nr = m1.rows ();                                 \
    octave_idx_type m1_nc = m1.cols ();                                 \
                                                                        \
    octave_idx_type m2_nr = m2.rows ();                                 \
    octave_idx_type m2_nc = m2.cols ();                                 \
                                                                        \
    if (m1_nr == 1 && m1_nc == 1)                                       \
      r = SparseBoolMatrix (F (m1.elem (0,0), m2));                     \
    else if (m1_nr == m2_nr && m1_nc == m2_nc)                          \
      {                                                                 \
        if (m1_nr != 0 || m1_nc != 0)                                   \
          {                                                             \
            /* Count num of nonzero elements */                         \
            octave_idx_type nel = 0;                                    \
            for (octave_idx_type j = 0; j < m1_nc; j++)                 \
              for (octave_idx_type i = 0; i < m1_nr; i++)               \
                if (m1.elem (i, j) OP m2.elem (i, j))                   \
                  nel++;                                                \
                                                                        \
            r = SparseBoolMatrix (m1_nr, m1_nc, nel);                   \
                                                                        \
            octave_idx_type ii = 0;                                     \
            r.cidx (0) = 0;                                             \
            for (octave_idx_type j = 0; j < m1_nc; j++)                 \
              {                                                         \
                for (octave_idx_type i = 0; i < m1_nr; i++)             \
                  {                                                     \
                    bool el = m1.elem (i, j) OP m2.elem (i, j);         \
                    if (el)                                             \
                      {                                                 \
                        r.data (ii) = el;                               \
                        r.ridx (ii++) = i;                              \
                      }                                                 \
                  }                                                     \
                r.cidx (j+1) = ii;                                      \
              }                                                         \
          }                                                             \
      }                                                                 \
    else                                                                \
      {                                                                 \
        if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0))   \
          octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc);   \
      }                                                                 \
    return r;                                                           \
  }

#define SPARSE_SMM_CMP_OPS(M1, M2)              \
  SPARSE_SMM_CMP_OP (mx_el_lt, <,  M1, M2)      \
  SPARSE_SMM_CMP_OP (mx_el_le, <=, M1, M2)      \
  SPARSE_SMM_CMP_OP (mx_el_ge, >=, M1, M2)      \
  SPARSE_SMM_CMP_OP (mx_el_gt, >,  M1, M2)      \
  SPARSE_SMM_CMP_OP (mx_el_eq, ==, M1, M2)      \
  SPARSE_SMM_CMP_OP (mx_el_ne, !=, M1, M2)

#define SPARSE_SMM_EQNE_OPS(M1, M2)             \
  SPARSE_SMM_CMP_OP (mx_el_eq, ==, M1, M2)      \
  SPARSE_SMM_CMP_OP (mx_el_ne, !=, M1, M2)

// Element-wise logical op (&& or ||) of sparse M1 with dense M2.
#define SPARSE_SMM_BOOL_OP(F, OP, M1, M2)                               \
  SparseBoolMatrix                                                      \
  F (const M1& m1, const M2& m2)                                        \
  {                                                                     \
    SparseBoolMatrix r;                                                 \
                                                                        \
    octave_idx_type m1_nr = m1.rows ();                                 \
    octave_idx_type m1_nc = m1.cols ();                                 \
                                                                        \
    octave_idx_type m2_nr = m2.rows ();                                 \
    octave_idx_type m2_nc = m2.cols ();                                 \
                                                                        \
    M1::element_type lhs_zero = M1::element_type ();                    \
    M2::element_type rhs_zero = M2::element_type ();                    \
                                                                        \
    if (m1_nr == 1 && m1_nc == 1)                                       \
      r = SparseBoolMatrix (F (m1.elem (0,0), m2));                     \
    else if (m1_nr == m2_nr && m1_nc == m2_nc)                          \
      {                                                                 \
        if (m1_nr != 0 || m1_nc != 0)                                   \
          {                                                             \
            /* Count num of nonzero elements */                         \
            octave_idx_type nel = 0;                                    \
            for (octave_idx_type j = 0; j < m1_nc; j++)                 \
              for (octave_idx_type i = 0; i < m1_nr; i++)               \
                if ((m1.elem (i, j) != lhs_zero)                        \
                    OP (m2.elem (i, j) != rhs_zero))                    \
                  nel++;                                                \
                                                                        \
            r = SparseBoolMatrix (m1_nr, m1_nc, nel);                   \
                                                                        \
            octave_idx_type ii = 0;                                     \
            r.cidx (0) = 0;                                             \
            for (octave_idx_type j = 0; j < m1_nc; j++)                 \
              {                                                         \
                for (octave_idx_type i = 0; i < m1_nr; i++)             \
                  {                                                     \
                    bool el = (m1.elem (i, j) != lhs_zero)              \
                              OP (m2.elem (i, j) != rhs_zero);          \
                    if (el)                                             \
                      {                                                 \
                        r.data (ii) = el;                               \
                        r.ridx (ii++) = i;                              \
                      }                                                 \
                  }                                                     \
                r.cidx (j+1) = ii;                                      \
              }                                                         \
          }                                                             \
      }                                                                 \
    else                                                                \
      {                                                                 \
        if ((m1_nr != 0 || m1_nc != 0) && (m2_nr != 0 || m2_nc != 0))   \
          octave::err_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc);   \
      }                                                                 \
    return r;                                                           \
  }

#define SPARSE_SMM_BOOL_OPS(M1, M2)             \
  SPARSE_SMM_BOOL_OP (mx_el_and, &&, M1, M2)    \
  SPARSE_SMM_BOOL_OP (mx_el_or,  ||, M1, M2)

// Avoid some code duplication.  Maybe we should use templates.
#define SPARSE_CUMSUM(RET_TYPE, ELT_TYPE, FCN) \ \ octave_idx_type nr = rows (); \ octave_idx_type nc = cols (); \ \ RET_TYPE retval; \ \ if (nr > 0 && nc > 0) \ { \ if ((nr == 1 && dim == -1) || dim == 1) \ /* Ugly!! Is there a better way? */ \ retval = transpose (). FCN (0) .transpose (); \ else \ { \ octave_idx_type nel = 0; \ for (octave_idx_type i = 0; i < nc; i++) \ { \ ELT_TYPE t = ELT_TYPE (); \ for (octave_idx_type j = cidx (i); j < cidx (i+1); j++) \ { \ t += data (j); \ if (t != ELT_TYPE ()) \ { \ if (j == cidx (i+1) - 1) \ nel += nr - ridx (j); \ else \ nel += ridx (j+1) - ridx (j); \ } \ } \ } \ retval = RET_TYPE (nr, nc, nel); \ retval.cidx (0) = 0; \ octave_idx_type ii = 0; \ for (octave_idx_type i = 0; i < nc; i++) \ { \ ELT_TYPE t = ELT_TYPE (); \ for (octave_idx_type j = cidx (i); j < cidx (i+1); j++) \ { \ t += data (j); \ if (t != ELT_TYPE ()) \ { \ if (j == cidx (i+1) - 1) \ { \ for (octave_idx_type k = ridx (j); k < nr; k++) \ { \ retval.data (ii) = t; \ retval.ridx (ii++) = k; \ } \ } \ else \ { \ for (octave_idx_type k = ridx (j); k < ridx (j+1); k++) \ { \ retval.data (ii) = t; \ retval.ridx (ii++) = k; \ } \ } \ } \ } \ retval.cidx (i+1) = ii; \ } \ } \ } \ else \ retval = RET_TYPE (nr,nc); \ \ return retval #define SPARSE_CUMPROD(RET_TYPE, ELT_TYPE, FCN) \ \ octave_idx_type nr = rows (); \ octave_idx_type nc = cols (); \ \ RET_TYPE retval; \ \ if (nr > 0 && nc > 0) \ { \ if ((nr == 1 && dim == -1) || dim == 1) \ /* Ugly!! Is there a better way? */ \ retval = transpose (). 
FCN (0) .transpose (); \ else \ { \ octave_idx_type nel = 0; \ for (octave_idx_type i = 0; i < nc; i++) \ { \ octave_idx_type jj = 0; \ for (octave_idx_type j = cidx (i); j < cidx (i+1); j++) \ { \ if (jj == ridx (j)) \ { \ nel++; \ jj++; \ } \ else \ break; \ } \ } \ retval = RET_TYPE (nr, nc, nel); \ retval.cidx (0) = 0; \ octave_idx_type ii = 0; \ for (octave_idx_type i = 0; i < nc; i++) \ { \ ELT_TYPE t = ELT_TYPE (1.); \ octave_idx_type jj = 0; \ for (octave_idx_type j = cidx (i); j < cidx (i+1); j++) \ { \ if (jj == ridx (j)) \ { \ t *= data (j); \ retval.data (ii) = t; \ retval.ridx (ii++) = jj++; \ } \ else \ break; \ } \ retval.cidx (i+1) = ii; \ } \ } \ } \ else \ retval = RET_TYPE (nr,nc); \ \ return retval #define SPARSE_BASE_REDUCTION_OP(RET_TYPE, EL_TYPE, ROW_EXPR, COL_EXPR, \ INIT_VAL, MT_RESULT) \ \ octave_idx_type nr = rows (); \ octave_idx_type nc = cols (); \ \ RET_TYPE retval; \ \ if (nr > 0 && nc > 0) \ { \ if ((nr == 1 && dim == -1) || dim == 1) \ { \ /* Define j here to allow fancy definition for prod method */ \ octave_idx_type j = 0; \ OCTAVE_LOCAL_BUFFER (EL_TYPE, tmp, nr); \ \ for (octave_idx_type i = 0; i < nr; i++) \ tmp[i] = INIT_VAL; \ for (j = 0; j < nc; j++) \ { \ for (octave_idx_type i = cidx (j); i < cidx (j + 1); i++) \ { \ ROW_EXPR; \ } \ } \ octave_idx_type nel = 0; \ for (octave_idx_type i = 0; i < nr; i++) \ if (tmp[i] != EL_TYPE ()) \ nel++; \ retval = RET_TYPE (nr, static_cast<octave_idx_type> (1), nel); \ retval.cidx (0) = 0; \ retval.cidx (1) = nel; \ nel = 0; \ for (octave_idx_type i = 0; i < nr; i++) \ if (tmp[i] != EL_TYPE ()) \ { \ retval.data (nel) = tmp[i]; \ retval.ridx (nel++) = i; \ } \ } \ else \ { \ OCTAVE_LOCAL_BUFFER (EL_TYPE, tmp, nc); \ \ for (octave_idx_type j = 0; j < nc; j++) \ { \ tmp[j] = INIT_VAL; \ for (octave_idx_type i = cidx (j); i < cidx (j + 1); i++) \ { \ COL_EXPR; \ } \ } \ octave_idx_type nel = 0; \ for (octave_idx_type i = 0; i < nc; i++) \ if (tmp[i] != EL_TYPE ()) \ nel++; \ retval = 
RET_TYPE (static_cast<octave_idx_type> (1), nc, nel); \ retval.cidx (0) = 0; \ nel = 0; \ for (octave_idx_type i = 0; i < nc; i++) \ if (tmp[i] != EL_TYPE ()) \ { \ retval.data (nel) = tmp[i]; \ retval.ridx (nel++) = 0; \ retval.cidx (i+1) = retval.cidx (i) + 1; \ } \ else \ retval.cidx (i+1) = retval.cidx (i); \ } \ } \ else if (nc == 0 && (nr == 0 || (nr == 1 && dim == -1))) \ { \ if (MT_RESULT) \ { \ retval = RET_TYPE (static_cast<octave_idx_type> (1), \ static_cast<octave_idx_type> (1), \ static_cast<octave_idx_type> (1)); \ retval.cidx (0) = 0; \ retval.cidx (1) = 1; \ retval.ridx (0) = 0; \ retval.data (0) = MT_RESULT; \ } \ else \ retval = RET_TYPE (static_cast<octave_idx_type> (1), \ static_cast<octave_idx_type> (1), \ static_cast<octave_idx_type> (0)); \ } \ else if (nr == 0 && (dim == 0 || dim == -1)) \ { \ if (MT_RESULT) \ { \ retval = RET_TYPE (static_cast<octave_idx_type> (1), nc, nc); \ retval.cidx (0) = 0; \ for (octave_idx_type i = 0; i < nc ; i++) \ { \ retval.ridx (i) = 0; \ retval.cidx (i+1) = i+1; \ retval.data (i) = MT_RESULT; \ } \ } \ else \ retval = RET_TYPE (static_cast<octave_idx_type> (1), nc, \ static_cast<octave_idx_type> (0)); \ } \ else if (nc == 0 && dim == 1) \ { \ if (MT_RESULT) \ { \ retval = RET_TYPE (nr, static_cast<octave_idx_type> (1), nr); \ retval.cidx (0) = 0; \ retval.cidx (1) = nr; \ for (octave_idx_type i = 0; i < nr; i++) \ { \ retval.ridx (i) = i; \ retval.data (i) = MT_RESULT; \ } \ } \ else \ retval = RET_TYPE (nr, static_cast<octave_idx_type> (1), \ static_cast<octave_idx_type> (0)); \ } \ else \ retval.resize (nr > 0, nc > 0); \ \ return retval #define SPARSE_REDUCTION_OP_ROW_EXPR(OP) \ tmp[ridx (i)] OP data (i) #define SPARSE_REDUCTION_OP_COL_EXPR(OP) \ tmp[j] OP data (i) #define SPARSE_REDUCTION_OP(RET_TYPE, EL_TYPE, OP, INIT_VAL, MT_RESULT) \ SPARSE_BASE_REDUCTION_OP (RET_TYPE, EL_TYPE, \ SPARSE_REDUCTION_OP_ROW_EXPR (OP), \ SPARSE_REDUCTION_OP_COL_EXPR (OP), \ INIT_VAL, MT_RESULT) // Don't break from this loop 
if the test succeeds because // we are looping over the rows and not the columns in the inner loop. #define SPARSE_ANY_ALL_OP_ROW_CODE(TEST_OP, TEST_TRUE_VAL) \ if (data (i) TEST_OP 0.0) \ tmp[ridx (i)] = TEST_TRUE_VAL; #define SPARSE_ANY_ALL_OP_COL_CODE(TEST_OP, TEST_TRUE_VAL) \ if (data (i) TEST_OP 0.0) \ { \ tmp[j] = TEST_TRUE_VAL; \ break; \ } #define SPARSE_ANY_ALL_OP(DIM, INIT_VAL, MT_RESULT, TEST_OP, TEST_TRUE_VAL) \ SPARSE_BASE_REDUCTION_OP (SparseBoolMatrix, char, \ SPARSE_ANY_ALL_OP_ROW_CODE (TEST_OP, TEST_TRUE_VAL), \ SPARSE_ANY_ALL_OP_COL_CODE (TEST_OP, TEST_TRUE_VAL), \ INIT_VAL, MT_RESULT) #define SPARSE_ALL_OP(DIM) \ if ((rows () == 1 && dim == -1) || dim == 1) \ return transpose (). all (0). transpose (); \ else \ { \ SPARSE_ANY_ALL_OP (DIM, (cidx (j+1) - cidx (j) < nr ? false : true), \ true, ==, false); \ } #define SPARSE_ANY_OP(DIM) SPARSE_ANY_ALL_OP (DIM, false, false, !=, true) template <typename RET_T, typename SM1, typename SM2> RET_T sparse_sparse_mul (const SM1& m, const SM2 a) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); octave_idx_type a_nr = a.rows (); octave_idx_type a_nc = a.cols (); if (nr == 1 && nc == 1) { typename RET_T::element_type s = m.elem (0,0); octave_idx_type nz = a.nnz (); RET_T r (a_nr, a_nc, nz); for (octave_idx_type i = 0; i < nz; i++) { octave_quit (); r.data (i) = s * a.data (i); r.ridx (i) = a.ridx (i); } for (octave_idx_type i = 0; i < a_nc + 1; i++) { octave_quit (); r.cidx (i) = a.cidx (i); } r.maybe_compress (true); return r; } else if (a_nr == 1 && a_nc == 1) { typename RET_T::element_type s = a.elem (0,0); octave_idx_type nz = m.nnz (); RET_T r (nr, nc, nz); for (octave_idx_type i = 0; i < nz; i++) { octave_quit (); r.data (i) = m.data (i) * s; r.ridx (i) = m.ridx (i); } for (octave_idx_type i = 0; i < nc + 1; i++) { octave_quit (); r.cidx (i) = m.cidx (i); } r.maybe_compress (true); return r; } else if (nc != a_nr) octave::err_nonconformant ("operator *", nr, nc, a_nr, a_nc); else { 
OCTAVE_LOCAL_BUFFER (octave_idx_type, w, nr); RET_T retval (nr, a_nc, static_cast<octave_idx_type> (0)); for (octave_idx_type i = 0; i < nr; i++) w[i] = 0; retval.xcidx (0) = 0; octave_idx_type nel = 0; for (octave_idx_type i = 0; i < a_nc; i++) { for (octave_idx_type j = a.cidx (i); j < a.cidx (i+1); j++) { octave_idx_type col = a.ridx (j); for (octave_idx_type k = m.cidx (col) ; k < m.cidx (col+1); k++) { if (w[m.ridx (k)] < i + 1) { w[m.ridx (k)] = i + 1; nel++; } octave_quit (); } } retval.xcidx (i+1) = nel; } if (nel == 0) return RET_T (nr, a_nc); else { for (octave_idx_type i = 0; i < nr; i++) w[i] = 0; OCTAVE_LOCAL_BUFFER (typename RET_T::element_type, Xcol, nr); retval.change_capacity (nel); /* The optimal break-point as estimated from simulations */ /* Note that Mergesort is O(nz log(nz)) while searching all */ /* values is O(nr), where nz here is nonzero per row of */ /* length nr. The test itself was then derived from the */ /* simulation with random square matrices and the observation */ /* of the number of nonzero elements in the output matrix */ /* it was found that the breakpoints were */ /* nr: 500 1000 2000 5000 10000 */ /* nz: 6 25 97 585 2202 */ /* The below is a simplication of the 'polyfit'-ed parameters */ /* to these breakpoints */ octave_idx_type n_per_col = (a_nc > 43000 ? 
43000 : (a_nc * a_nc) / 43000); octave_idx_type ii = 0; octave_idx_type *ri = retval.xridx (); octave_sort<octave_idx_type> sort; for (octave_idx_type i = 0; i < a_nc ; i++) { if (retval.xcidx (i+1) - retval.xcidx (i) > n_per_col) { for (octave_idx_type j = a.cidx (i); j < a.cidx (i+1); j++) { octave_idx_type col = a.ridx (j); typename SM2::element_type tmpval = a.data (j); for (octave_idx_type k = m.cidx (col) ; k < m.cidx (col+1); k++) { octave_quit (); octave_idx_type row = m.ridx (k); if (w[row] < i + 1) { w[row] = i + 1; Xcol[row] = tmpval * m.data (k); } else Xcol[row] += tmpval * m.data (k); } } for (octave_idx_type k = 0; k < nr; k++) if (w[k] == i + 1) { retval.xdata (ii) = Xcol[k]; retval.xridx (ii++) = k; } } else { for (octave_idx_type j = a.cidx (i); j < a.cidx (i+1); j++) { octave_idx_type col = a.ridx (j); typename SM2::element_type tmpval = a.data (j); for (octave_idx_type k = m.cidx (col) ; k < m.cidx (col+1); k++) { octave_quit (); octave_idx_type row = m.ridx (k); if (w[row] < i + 1) { w[row] = i + 1; retval.xridx (ii++) = row; Xcol[row] = tmpval * m.data (k); } else Xcol[row] += tmpval * m.data (k); } } sort.sort (ri + retval.xcidx (i), ii - retval.xcidx (i)); for (octave_idx_type k = retval.xcidx (i); k < ii; k++) retval.xdata (k) = Xcol[retval.xridx (k)]; } } retval.maybe_compress (true); return retval; } } } template <typename RET_T, typename SMT, typename MT> RET_T sparse_full_mul (const SMT& m, const MT& a) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); octave_idx_type a_nr = a.rows (); octave_idx_type a_nc = a.cols (); if (nr == 1 && nc == 1) { RET_T retval = m.elem (0,0) * a; return retval; } else if (nc != a_nr) octave::err_nonconformant ("operator *", nr, nc, a_nr, a_nc); else { typename RET_T::element_type zero = typename RET_T::element_type (); RET_T retval (nr, a_nc, zero); for (octave_idx_type i = 0; i < a_nc ; i++) { for (octave_idx_type j = 0; j < a_nr; j++) { octave_quit (); typename MT::element_type tmpval = 
a.elem (j,i); for (octave_idx_type k = m.cidx (j) ; k < m.cidx (j+1); k++) retval.elem (m.ridx (k),i) += tmpval * m.data (k); } } return retval; } } template <typename RET_T, typename SMT, typename MT, typename CONJ_OP> RET_T sparse_full_trans_mul (const SMT& m, const MT& a, CONJ_OP conj_op) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); octave_idx_type a_nr = a.rows (); octave_idx_type a_nc = a.cols (); if (nr == 1 && nc == 1) { RET_T retval = conj_op (m.elem (0,0)) * a; return retval; } else if (nr != a_nr) octave::err_nonconformant ("operator *", nc, nr, a_nr, a_nc); else { RET_T retval (nc, a_nc); for (octave_idx_type i = 0; i < a_nc ; i++) { for (octave_idx_type j = 0; j < nc; j++) { octave_quit (); typename RET_T::element_type acc = typename RET_T::element_type (); for (octave_idx_type k = m.cidx (j) ; k < m.cidx (j+1); k++) acc += a.elem (m.ridx (k),i) * conj_op (m.data (k)); retval.xelem (j,i) = acc; } } return retval; } } template <typename RET_T, typename MT, typename SMT> RET_T full_sparse_mul (const MT& m, const SMT& a) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); octave_idx_type a_nr = a.rows (); octave_idx_type a_nc = a.cols (); if (a_nr == 1 && a_nc == 1) { RET_T retval = m * a.elem (0,0); return retval; } else if (nc != a_nr) octave::err_nonconformant ("operator *", nr, nc, a_nr, a_nc); else { typename RET_T::element_type zero = typename RET_T::element_type (); RET_T retval (nr, a_nc, zero); for (octave_idx_type i = 0; i < a_nc ; i++) { octave_quit (); for (octave_idx_type j = a.cidx (i); j < a.cidx (i+1); j++) { octave_idx_type col = a.ridx (j); typename SMT::element_type tmpval = a.data (j); for (octave_idx_type k = 0 ; k < nr; k++) retval.xelem (k,i) += tmpval * m.elem (k,col); } } return retval; } } template <typename RET_T, typename MT, typename SMT, typename CONJ_OP> RET_T full_sparse_mul_trans (const MT& m, const SMT& a, CONJ_OP conj_op) { octave_idx_type nr = m.rows (); octave_idx_type nc = m.cols (); 
octave_idx_type a_nr = a.rows (); octave_idx_type a_nc = a.cols (); if (a_nr == 1 && a_nc == 1) { RET_T retval = m * conj_op (a.elem (0,0)); return retval; } else if (nc != a_nc) octave::err_nonconformant ("operator *", nr, nc, a_nc, a_nr); else { typename RET_T::element_type zero = typename RET_T::element_type (); RET_T retval (nr, a_nr, zero); for (octave_idx_type i = 0; i < a_nc ; i++) { octave_quit (); for (octave_idx_type j = a.cidx (i); j < a.cidx (i+1); j++) { octave_idx_type col = a.ridx (j); typename SMT::element_type tmpval = conj_op (a.data (j)); for (octave_idx_type k = 0 ; k < nr; k++) retval.xelem (k,col) += tmpval * m.elem (k,i); } } return retval; } } #endif