0
|
1 /* |
|
2 * conv2: 2D convolution for octave |
|
3 * |
|
4 * Copyright (C) 1999 Andy Adler |
|
5 * This code has no warranty whatsoever. |
|
6 * Do what you like with this code as long as you |
|
7 * leave this copyright in place. |
|
8 * |
|
9 * $Id$ |
|
10 |
|
11 ## 2000-05-17: Paul Kienzle |
|
12 ## * change argument to vector conversion to work for 2.1 series octave |
|
13 ## as well as 2.0 series |
|
14 ## 2001-02-05: Paul Kienzle |
|
15 ## * accept complex arguments |
|
16 |
|
17 */ |
|
18 |
|
19 #include <octave/oct.h> |
|
20 |
|
21 #define MAX(a,b) ((a) > (b) ? (a) : (b)) |
|
22 |
|
23 #define SHAPE_FULL 1 |
|
24 #define SHAPE_SAME 2 |
|
25 #define SHAPE_VALID 3 |
|
26 |
|
27 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
28 extern MArray2<double> |
|
29 conv2 (MArray<double>&, MArray<double>&, MArray2<double>&, int); |
|
30 |
|
31 extern MArray2<Complex> |
|
32 conv2 (MArray<Complex>&, MArray<Complex>&, MArray2<Complex>&, int); |
|
33 #endif |
|
34 |
|
35 template <class T> |
|
36 MArray2<T> |
|
37 conv2 (MArray<T>& R, MArray<T>& C, MArray2<T>& A, int ishape) |
|
38 { |
|
39 int Rn= R.length(); |
|
40 int Cm= C.length(); |
|
41 int Am = A.rows(); |
|
42 int An = A.columns(); |
|
43 |
|
44 /* |
|
45 * Here we calculate the size of the output matrix, |
|
46 * in order to stay Matlab compatible, it is based |
|
47 * on the third parameter if its separable, and the |
|
48 * first if it's not |
|
49 */ |
|
50 int outM, outN, edgM, edgN; |
|
51 if ( ishape == SHAPE_FULL ) { |
|
52 outM= Am + Cm - 1; |
|
53 outN= An + Rn - 1; |
|
54 edgM= Cm - 1; |
|
55 edgN= Rn - 1; |
|
56 } else if ( ishape == SHAPE_SAME ) { |
|
57 outM= Am; |
|
58 outN= An; |
|
59 // Matlab seems to arbitrarily choose this convention for |
|
60 // 'same' with even length R, C |
|
61 edgM= ( Cm - 1) /2; |
|
62 edgN= ( Rn - 1) /2; |
|
63 } else if ( ishape == SHAPE_VALID ) { |
|
64 outM= Am - Cm + 1; |
|
65 outN= An - Rn + 1; |
|
66 edgM= edgN= 0; |
|
67 } |
|
68 |
|
69 // printf("A(%d,%d) C(%d) R(%d) O(%d,%d) E(%d,%d)\n", |
|
70 // Am,An, Cm,Rn, outM, outN, edgM, edgN); |
|
71 MArray2<T> O(outM,outN); |
|
72 /* |
|
73 * T accumulated the 1-D conv for each row, before calculating |
|
74 * the convolution in the other direction |
|
75 * There is no efficiency advantage to doing it in either direction |
|
76 * first |
|
77 */ |
|
78 |
|
79 MArray<T> X( An ); |
|
80 |
|
81 for( int oi=0; oi < outM; oi++ ) { |
|
82 for( int oj=0; oj < An; oj++ ) { |
|
83 T sum=0; |
|
84 |
|
85 int ci= Cm - 1 - MAX(0, edgM-oi); |
|
86 int ai= MAX(0, oi-edgM) ; |
|
87 const T* Ad= A.data() + ai + Am*oj; |
|
88 const T* Cd= C.data() + ci; |
|
89 for( ; ci >= 0 && ai < Am; |
|
90 ci--, Cd--, ai++, Ad++) { |
|
91 sum+= (*Ad) * (*Cd); |
|
92 } // for( int ci= |
|
93 |
|
94 X(oj)= sum; |
|
95 } // for( int oj=0 |
|
96 |
|
97 for( int oj=0; oj < outN; oj++ ) { |
|
98 T sum=0; |
|
99 |
|
100 int rj= Rn - 1 - MAX(0, edgN-oj); |
|
101 int aj= MAX(0, oj-edgN) ; |
|
102 const T* Xd= X.data() + aj; |
|
103 const T* Rd= R.data() + rj; |
|
104 |
|
105 for( ; rj >= 0 && aj < An; |
|
106 rj--, Rd--, aj++, Xd++) { |
|
107 sum+= (*Xd) * (*Rd); |
|
108 } //for( int rj= |
|
109 |
|
110 O(oi,oj)= sum; |
|
111 } // for( int oj=0 |
|
112 } // for( int oi=0 |
|
113 |
|
114 return O; |
|
115 } |
|
116 |
|
117 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
118 extern MArray2<double> |
|
119 conv2 (MArray2<double>&, MArray2<double>&, int); |
|
120 |
|
121 extern MArray2<Complex> |
|
122 conv2 (MArray2<Complex>&, MArray2<Complex>&, int); |
|
123 #endif |
|
124 |
|
125 template <class T> |
|
126 MArray2<T> |
|
127 conv2 (MArray2<T>&A, MArray2<T>&B, int ishape) |
|
128 { |
|
129 /* Convolution works fastest if we choose the A matrix to be |
|
130 * the largest. |
|
131 * |
|
132 * Here we calculate the size of the output matrix, |
|
133 * in order to stay Matlab compatible, it is based |
|
134 * on the third parameter if its separable, and the |
|
135 * first if it's not |
|
136 * |
|
137 * NOTE in order to be Matlab compatible, we give |
|
138 * wrong sizes for 'valid' if the smallest matrix is first |
|
139 */ |
|
140 |
|
141 int Am = A.rows(); |
|
142 int An = A.columns(); |
|
143 int Bm = B.rows(); |
|
144 int Bn = B.columns(); |
|
145 |
|
146 int outM, outN, edgM, edgN; |
|
147 if ( ishape == SHAPE_FULL ) { |
|
148 outM= Am + Bm - 1; |
|
149 outN= An + Bn - 1; |
|
150 edgM= Bm - 1; |
|
151 edgN= Bn - 1; |
|
152 } else if ( ishape == SHAPE_SAME ) { |
|
153 outM= Am; |
|
154 outN= An; |
|
155 // Matlab seems to arbitrarily choose this convention for |
|
156 // 'same' with even length R, C |
|
157 edgM= ( Bm - 1) /2; |
|
158 edgN= ( Bn - 1) /2; |
|
159 } else if ( ishape == SHAPE_VALID ) { |
|
160 outM= Am - Bm + 1; |
|
161 outN= An - Bn + 1; |
|
162 edgM= edgN= 0; |
|
163 } |
|
164 |
|
165 // printf("A(%d,%d) B(%d,%d) O(%d,%d) E(%d,%d)\n", |
|
166 // Am,An, Bm,Bn, outM, outN, edgM, edgN); |
|
167 MArray2<T> O(outM,outN); |
|
168 |
|
169 for( int oi=0; oi < outM; oi++ ) { |
|
170 for( int oj=0; oj < outN; oj++ ) { |
|
171 T sum=0; |
|
172 |
|
173 for( int bj= Bn - 1 - MAX(0, edgN-oj), |
|
174 aj= MAX(0, oj-edgN); |
|
175 bj >= 0 && aj < An; |
|
176 bj--, aj++) { |
|
177 int bi= Bm - 1 - MAX(0, edgM-oi); |
|
178 int ai= MAX(0, oi-edgM); |
|
179 const T* Ad= A.data() + ai + Am*aj; |
|
180 const T* Bd= B.data() + bi + Bm*bj; |
|
181 |
|
182 for( ; bi >= 0 && ai < Am; |
|
183 bi--, Bd--, ai++, Ad++) { |
|
184 sum+= (*Ad) * (*Bd); |
|
185 /* |
|
186 * It seems to be about 2.5 times faster to use pointers than |
|
187 * to do this |
|
188 * sum+= A(ai,aj) * B(bi,bj); |
|
189 */ |
|
190 } // for( int bi= |
|
191 } //for( int bj= |
|
192 |
|
193 O(oi,oj)= sum; |
|
194 } // for( int oj= |
|
195 } // for( int oi= |
|
196 return O; |
|
197 } |
|
198 |
|
199 DEFUN_DLD (conv2, args, , |
|
200 "[...] = conv2 (...) |
|
201 CONV2: do 2 dimensional convolution |
|
202 |
|
203 c= conv2(a,b) -> same as c= conv2(a,b,'full') |
|
204 |
|
205 c= conv2(a,b,shape) returns 2-D convolution of a and b |
|
206 where the size of c is given by |
|
207 shape= 'full' -> returns full 2-D convolution |
|
208 shape= 'same' -> same size as a. 'central' part of convolution |
|
209 shape= 'valid' -> only parts which do not include zero-padded edges |
|
210 |
|
211 c= conv2(a,b,shape) returns 2-D convolution of a and b |
|
212 |
|
213 c= conv2(v1,v2,a) -> same as c= conv2(v1,v2,a,'full') |
|
214 |
|
215 c= conv2(v1,v2,a,shape) returns convolution of a by vector v1 |
|
216 in the column direction and vector v2 in the row direction ") |
|
217 { |
|
218 octave_value_list retval; |
|
219 octave_value tmp; |
|
220 int nargin = args.length (); |
|
221 string shape= "full"; |
|
222 bool separable= false; |
|
223 int ishape; |
|
224 |
|
225 if (nargin < 2 ) { |
|
226 print_usage ("conv2"); |
|
227 return retval; |
|
228 } else if (nargin == 3) { |
|
229 if ( args(2).is_string() ) |
|
230 shape= args(2).string_value(); |
|
231 else |
|
232 separable= true; |
|
233 } else if (nargin >= 4) { |
|
234 separable= true; |
|
235 shape= args(3).string_value(); |
|
236 } |
|
237 if ( shape == "full" ) ishape = SHAPE_FULL; |
|
238 else if ( shape == "same" ) ishape = SHAPE_SAME; |
|
239 else if ( shape == "valid" ) ishape = SHAPE_VALID; |
|
240 else { // if ( shape |
|
241 error("Shape type not valid"); |
|
242 print_usage ("conv2"); |
|
243 return retval; |
|
244 } |
|
245 |
|
246 if (separable) { |
|
247 /* |
|
248 * Check that the first two parameters are vectors |
|
249 * if we're doing separable |
|
250 */ |
|
251 if ( !( 1== args(0).rows() || 1== args(0).columns() ) || |
|
252 !( 1== args(1).rows() || 1== args(1).columns() ) ) { |
|
253 print_usage ("conv2"); |
|
254 return retval; |
|
255 } |
|
256 |
|
257 if (args(0).is_complex_type() || args(1).is_complex_type() |
|
258 || args(2).is_complex_type()) { |
|
259 ComplexColumnVector v1 (args(0).complex_vector_value()); |
|
260 ComplexColumnVector v2 (args(1).complex_vector_value()); |
|
261 ComplexMatrix a (args(2).complex_matrix_value()); |
|
262 ComplexMatrix c(conv2(v1, v2, a, ishape)); |
|
263 retval(0) = c; |
|
264 } else { |
|
265 ColumnVector v1 (args(0).vector_value()); |
|
266 ColumnVector v2 (args(1).vector_value()); |
|
267 Matrix a (args(2).matrix_value()); |
|
268 Matrix c(conv2(v1, v2, a, ishape)); |
|
269 retval(0) = c; |
|
270 } |
|
271 } else { // if (separable) |
|
272 |
|
273 if (args(0).is_complex_type() || args(1).is_complex_type()) { |
|
274 ComplexMatrix a (args(0).complex_matrix_value()); |
|
275 ComplexMatrix b (args(1).complex_matrix_value()); |
|
276 ComplexMatrix c(conv2(a, b, ishape)); |
|
277 retval(0) = c; |
|
278 } else { |
|
279 Matrix a (args(0).matrix_value()); |
|
280 Matrix b (args(1).matrix_value()); |
|
281 Matrix c(conv2(a, b, ishape)); |
|
282 retval(0) = c; |
|
283 } |
|
284 |
|
285 } // if (separable) |
|
286 |
|
287 return retval; |
|
288 } |
|
289 |
|
290 |
|
291 template MArray2<double> |
|
292 conv2 (MArray<double>&, MArray<double>&, MArray2<double>&, int); |
|
293 |
|
294 template MArray2<double> |
|
295 conv2 (MArray2<double>&, MArray2<double>&, int); |
|
296 |
|
297 template MArray2<Complex> |
|
298 conv2 (MArray<Complex>&, MArray<Complex>&, MArray2<Complex>&, int); |
|
299 |
|
300 template MArray2<Complex> |
|
301 conv2 (MArray2<Complex>&, MArray2<Complex>&, int); |