comparison libinterp/corefcn/sparse-xpow.cc @ 31224:45984c799215

sparse-xpow.cc: Use faster multiplication technique based on input matrix sparsity
author Arun Giridhar <arungiridhar@gmail.com>
date Sat, 10 Sep 2022 15:44:05 -0400
parents 796f54d4ddbf
children 3eab70385569
comparison
equal deleted inserted replaced
31223:94488ab70e12 31224:45984c799215
107 107
108 SparseMatrix result (atmp); 108 SparseMatrix result (atmp);
109 109
110 btmp--; 110 btmp--;
111 111
112 while (btmp > 0) 112 // There are two approaches to the actual exponentiation.
113 { 113 // Exponentiation by squaring uses only a logarithmic number
114 if (btmp & 1) 114 // of multiplications but the matrices it multiplies tend to be dense
115 // towards the end.
116 // Linear multiplication uses a linear number of multiplications
117 // but one of the matrices it uses will be as sparse as the original matrix.
118 //
119 // The time to multiply fixed-size matrices is strongly affected by their
120 // sparsity. Denser matrices take much longer to multiply together.
121 // See this URL for a worked-through example:
122 // https://octave.discourse.group/t/3216/4
123 //
124 // The tradeoff is between many fast multiplications or a few slow ones.
125 //
126 // Large exponents favor the squaring technique, and sparse matrices favor
127 // linear multiplication.
128 //
129 // We calculate a threshold based on the sparsity of the input
130 // and use squaring for exponents larger than that.
131 //
132 // FIXME: Improve this threshold calculation.
133
134 uint64_t sparsity = atmp.numel() / atmp.nnz(); // reciprocal of density
135 int threshold = (sparsity >= 10000) ? 40
136 : (sparsity >= 1000) ? 30
137 : (sparsity >= 100) ? 20
138 : 3;
139
140 if (btmp > threshold) // use squaring technique
141 {
142 while (btmp > 0)
143 {
144 if (btmp & 1)
145 result = result * atmp;
146
147 btmp >>= 1;
148
149 if (btmp > 0)
150 atmp = atmp * atmp;
151 }
152 }
153 else // use linear multiplication
154 {
155 for (int i = 0; i < btmp; i++)
115 result = result * atmp; 156 result = result * atmp;
116
117 btmp >>= 1;
118
119 if (btmp > 0)
120 atmp = atmp * atmp;
121 } 157 }
122 158
123 retval = result; 159 retval = result;
124 } 160 }
125 161