# HG changeset patch # User Jaroslav Hajek # Date 1267520652 -3600 # Node ID 978f5c94b11f1cc9994493e310bf985624800dbf # Parent 796258b94b225800f53f91a665f40c2729fca62b initial implementation of conv2/convn in liboctave diff -r 796258b94b22 -r 978f5c94b11f liboctave/ChangeLog --- a/liboctave/ChangeLog Tue Mar 02 01:14:11 2010 +0100 +++ b/liboctave/ChangeLog Tue Mar 02 10:04:12 2010 +0100 @@ -1,3 +1,8 @@ +2010-03-02 Jaroslav Hajek + + * oct-convn.h, oct-convn.cc: New sources. + * Makefile.am: Include them. + 2010-03-01 David Bateman * Sparse.cc (Sparse::maybe_delete_elements (idxx_vector&)): diff -r 796258b94b22 -r 978f5c94b11f liboctave/Makefile.am --- a/liboctave/Makefile.am Tue Mar 02 01:14:11 2010 +0100 +++ b/liboctave/Makefile.am Tue Mar 02 10:04:12 2010 +0100 @@ -211,6 +211,7 @@ mach-info.h \ oct-alloc.h \ oct-cmplx.h \ + oct-convn.h \ oct-env.h \ oct-fftw.h \ oct-group.h \ @@ -425,6 +426,7 @@ lo-utils.cc \ mach-info.cc \ oct-alloc.cc \ + oct-convn.cc \ oct-env.cc \ oct-fftw.cc \ oct-glob.cc \ diff -r 796258b94b22 -r 978f5c94b11f src/ChangeLog --- a/src/ChangeLog Tue Mar 02 01:14:11 2010 +0100 +++ b/src/ChangeLog Tue Mar 02 10:04:12 2010 +0100 @@ -1,3 +1,8 @@ +2010-03-02 Jaroslav Hajek + + * DLD-FUNCTIONS/conv.cc (Fconv2): Rewrite using convn from liboctave. + (Fconvn): New DEFUN. + 2010-03-01 John W. Eaton * (str2double1): Pass argument as const reference and make diff -r 796258b94b22 -r 978f5c94b11f src/DLD-FUNCTIONS/conv2.cc --- a/src/DLD-FUNCTIONS/conv2.cc Tue Mar 02 01:14:11 2010 +0100 +++ b/src/DLD-FUNCTIONS/conv2.cc Tue Mar 02 10:04:12 2010 +0100 @@ -25,6 +25,8 @@ #include #endif +#include "oct-convn.h" + #include "defun-dld.h" #include "error.h" #include "oct-obj.h" @@ -32,235 +34,6 @@ enum Shape { SHAPE_FULL, SHAPE_SAME, SHAPE_VALID }; -#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) -extern MArray -conv2 (const MArray&, const MArray&, const MArray&, - Shape); - -extern MArray -conv2 (const MArray&, const MArray&, - const MArray&, Shape); - -extern MArray -conv2 (const MArray&, const MArray&, const MArray&, - Shape); - -extern MArray -conv2 (const MArray&, const MArray&, - const MArray&, Shape); -#endif - -template -MArray -conv2 (const MArray& R, const MArray& C, const MArray& A, Shape ishape) -{ - octave_idx_type Rn = R.length (); - octave_idx_type Cm = C.length (); - octave_idx_type Am = A.rows (); - octave_idx_type An = A.columns (); - - // Calculate the size of the output matrix: - // in order to stay Matlab compatible, it is based - // on the third parameter if it's separable, and the - // first if it's not - - octave_idx_type outM = 0; - octave_idx_type outN = 0; - octave_idx_type edgM = 0; - octave_idx_type edgN = 0; - - switch (ishape) - { - case SHAPE_FULL: - outM = Am + Cm - 1; - outN = An + Rn - 1; - edgM = Cm - 1; - edgN = Rn - 1; - break; - - case SHAPE_SAME: - outM = Am; - outN = An; - // Follow the Matlab convention (ie + instead of -) - edgM = (Cm - 1) /2; - edgN = (Rn - 1) /2; - break; - - case SHAPE_VALID: - outM = Am - Cm + 1; - outN = An - Rn + 1; - if (outM < 0) - outM = 0; - if (outN < 0) - outN = 0; - edgM = edgN = 0; - break; - - default: - error ("conv2: invalid value of parameter ishape"); - } - - MArray O (outM, outN); - - T *Od0 = O.fortran_vec (); - - // X accumulates the 1-D conv for each row, before calculating - // the convolution in the other direction. There is no efficiency - // advantage to doing it in either direction first. - - MArray X (An, 1); - - T *Xd0 = X.fortran_vec (); - - for (octave_idx_type oi = 0; oi < outM; oi++) - { - T *Xd = Xd0; - - octave_idx_type ci0 = Cm - 1 - std::max (0, edgM-oi); - octave_idx_type ai0 = std::max (0, oi-edgM); - - const T *Cd0 = C.data () + ci0; - const T *Ad0 = A.data () + ai0; - - for (octave_idx_type oj = 0; oj < An; oj++) - { - T sum = 0; - - const T* Cd = Cd0; - const T* Ad = Ad0 + Am*oj; - - for (octave_idx_type ci = ci0, ai = ai0; ci >= 0 && ai < Am; - ci--, ai++) - sum += (*Ad++) * (*Cd--); - - *Xd++ = sum; - } - - T *Od = Od0 + oi; - - const T *Rd0 = R.data (); - - for (octave_idx_type oj = 0; oj < outN; oj++) - { - T sum = 0; - - octave_idx_type rj = Rn - 1 - std::max (0, edgN-oj); - octave_idx_type aj = std::max (0, oj-edgN); - - Xd = Xd0 + aj; - const T* Rd = Rd0 + rj; - - for ( ; rj >= 0 && aj < An; rj--, aj++) - sum += (*Xd++) * (*Rd--); - - *Od = sum; - Od += outM; - } - } - - return O; -} - -#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) -extern MArray -conv2 (MArray&, MArray&, Shape); - -extern MArray -conv2 (MArray&, MArray&, Shape); - -extern MArray -conv2 (MArray&, MArray&, Shape); - -extern MArray -conv2 (MArray&, MArray&, Shape); -#endif - -template -MArray -conv2 (const MArray& A, const MArray& B, Shape ishape) -{ - // Convolution works fastest if we choose the A matrix to be - // the largest. - - // Here we calculate the size of the output matrix, - // in order to stay Matlab compatible, it is based - // on the third parameter if it's separable, and the - // first if it's not - - // NOTE in order to be Matlab compatible, we give argueably - // wrong sizes for 'valid' if the smallest matrix is first - - octave_idx_type Am = A.rows (); - octave_idx_type An = A.columns (); - octave_idx_type Bm = B.rows (); - octave_idx_type Bn = B.columns (); - - octave_idx_type outM = 0; - octave_idx_type outN = 0; - octave_idx_type edgM = 0; - octave_idx_type edgN = 0; - - switch (ishape) - { - case SHAPE_FULL: - outM = Am + Bm - 1; - outN = An + Bn - 1; - edgM = Bm - 1; - edgN = Bn - 1; - break; - - case SHAPE_SAME: - outM = Am; - outN = An; - edgM = (Bm - 1) /2; - edgN = (Bn - 1) /2; - break; - - case SHAPE_VALID: - outM = Am - Bm + 1; - outN = An - Bn + 1; - if (outM < 0) - outM = 0; - if (outN < 0) - outN = 0; - edgM = edgN = 0; - break; - } - - MArray O (outM, outN); - - T *Od = O.fortran_vec (); - - for (octave_idx_type oj = 0; oj < outN; oj++) - { - octave_idx_type aj0 = std::max (0, oj-edgN); - octave_idx_type bj0 = Bn - 1 - std::max (0, edgN-oj); - - for (octave_idx_type oi = 0; oi < outM; oi++) - { - T sum = 0; - - octave_idx_type bi0 = Bm - 1 - std::max (0, edgM-oi); - octave_idx_type ai0 = std::max (0, oi-edgM); - - for (octave_idx_type aj = aj0, bj = bj0; bj >= 0 && aj < An; - bj--, aj++) - { - const T* Ad = A.data () + ai0 + Am*aj; - const T* Bd = B.data () + bi0 + Bm*bj; - - for (octave_idx_type ai = ai0, bi = bi0; bi >= 0 && ai < Am; - bi--, ai++) - sum += (*Ad++) * (*Bd--); - } - - *Od++ = sum; - } - } - - return O; -} - /* %!test %! b = [0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18]; @@ -270,7 +43,7 @@ %! b = single([0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18]); %! assert(conv2(single([0,1;1,2]),single([1,2,3;4,5,6;7,8,9])),b); -%!assert (conv2 (1:2, 1:3, [1,2;3,4;5,6]), +%!assert (conv2 (1:3, 1:2, [1,2;3,4;5,6]), %! [1,4,4;5,18,16;14,48,40;19,62,48;15,48,36;]); */ @@ -302,7 +75,7 @@ int nargin = args.length (); std::string shape = "full"; //default bool separable = false; - Shape ishape; + convn_type ct; if (nargin < 2) { @@ -323,11 +96,11 @@ } if (shape == "full") - ishape = SHAPE_FULL; + ct = convn_full; else if (shape == "same") - ishape = SHAPE_SAME; + ct = convn_same; else if (shape == "valid") - ishape = SHAPE_VALID; + ct = convn_valid; else { error ("conv2: shape type not valid"); @@ -354,21 +127,26 @@ || args(1).is_complex_type () || args(2).is_complex_type ()) { - FloatComplexColumnVector v1 (args(0).float_complex_vector_value ()); - FloatComplexColumnVector v2 (args(1).float_complex_vector_value ()); FloatComplexMatrix a (args(2).float_complex_matrix_value ()); - FloatComplexMatrix c (conv2 (v1, v2, a, ishape)); - if (! error_state) - retval = c; + if (args(1).is_real_type () && args(2).is_real_type ()) + { + FloatColumnVector v1 (args(0).float_vector_value ()); + FloatRowVector v2 (args(1).float_vector_value ()); + retval = convn (a, v1, v2, ct); + } + else + { + FloatComplexColumnVector v1 (args(0).float_complex_vector_value ()); + FloatComplexRowVector v2 (args(1).float_complex_vector_value ()); + retval = convn (a, v1, v2, ct); + } } else { FloatColumnVector v1 (args(0).float_vector_value ()); - FloatColumnVector v2 (args(1).float_vector_value ()); + FloatRowVector v2 (args(1).float_vector_value ()); FloatMatrix a (args(2).float_matrix_value ()); - FloatMatrix c (conv2 (v1, v2, a, ishape)); - if (! error_state) - retval = c; + retval = convn (a, v1, v2, ct); } } else @@ -377,21 +155,26 @@ || args(1).is_complex_type () || args(2).is_complex_type ()) { - ComplexColumnVector v1 (args(0).complex_vector_value ()); - ComplexColumnVector v2 (args(1).complex_vector_value ()); ComplexMatrix a (args(2).complex_matrix_value ()); - ComplexMatrix c (conv2 (v1, v2, a, ishape)); - if (! error_state) - retval = c; + if (args(1).is_real_type () && args(2).is_real_type ()) + { + ColumnVector v1 (args(0).vector_value ()); + RowVector v2 (args(1).vector_value ()); + retval = convn (a, v1, v2, ct); + } + else + { + ComplexColumnVector v1 (args(0).complex_vector_value ()); + ComplexRowVector v2 (args(1).complex_vector_value ()); + retval = convn (a, v1, v2, ct); + } } else { ColumnVector v1 (args(0).vector_value ()); - ColumnVector v2 (args(1).vector_value ()); + RowVector v2 (args(1).vector_value ()); Matrix a (args(2).matrix_value ()); - Matrix c (conv2 (v1, v2, a, ishape)); - if (! error_state) - retval = c; + retval = convn (a, v1, v2, ct); } } } // if (separable) @@ -404,18 +187,22 @@ || args(1).is_complex_type ()) { FloatComplexMatrix a (args(0).float_complex_matrix_value ()); - FloatComplexMatrix b (args(1).float_complex_matrix_value ()); - FloatComplexMatrix c (conv2 (a, b, ishape)); - if (! error_state) - retval = c; + if (args(1).is_real_type ()) + { + FloatMatrix b (args(1).float_matrix_value ()); + retval = convn (a, b, ct); + } + else + { + FloatComplexMatrix b (args(1).float_complex_matrix_value ()); + retval = convn (a, b, ct); + } } else { FloatMatrix a (args(0).float_matrix_value ()); FloatMatrix b (args(1).float_matrix_value ()); - FloatMatrix c (conv2 (a, b, ishape)); - if (! error_state) - retval = c; + retval = convn (a, b, ct); } } else @@ -424,18 +211,22 @@ || args(1).is_complex_type ()) { ComplexMatrix a (args(0).complex_matrix_value ()); - ComplexMatrix b (args(1).complex_matrix_value ()); - ComplexMatrix c (conv2 (a, b, ishape)); - if (! error_state) - retval = c; + if (args(1).is_real_type ()) + { + Matrix b (args(1).matrix_value ()); + retval = convn (a, b, ct); + } + else + { + ComplexMatrix b (args(1).complex_matrix_value ()); + retval = convn (a, b, ct); + } } else { Matrix a (args(0).matrix_value ()); Matrix b (args(1).matrix_value ()); - Matrix c (conv2 (a, b, ishape)); - if (! error_state) - retval = c; + retval = convn (a, b, ct); } } @@ -444,16 +235,107 @@ return retval; } -template MArray -conv2 (const MArray&, const MArray&, const MArray&, - Shape); +DEFUN_DLD (convn, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {y =} conv2 (@var{a}, @var{b}, @var{shape})\n\ +\n\ +Returns n-D convolution of @var{a} and @var{b} where the size\n\ +of @var{c} is given by\n\ +\n\ +@table @asis\n\ +@item @var{shape} = 'full'\n\ +returns full n-D convolution\n\ +@item @var{shape} = 'same'\n\ +same size as a. 'central' part of convolution\n\ +@item @var{shape} = 'valid'\n\ +only parts which do not include zero-padded edges\n\ +@end table\n\ +\n\ +By default @var{shape} is 'full'.\n\ +@end deftypefn") +{ + octave_value retval; + octave_value tmp; + int nargin = args.length (); + std::string shape = "full"; //default + bool separable = false; + convn_type ct; -template MArray -conv2 (const MArray&, const MArray&, Shape); + if (nargin < 2 || nargin > 3) + { + print_usage (); + return retval; + } + else if (nargin == 3) + { + if (args(2).is_string ()) + shape = args(2).string_value (); + else + separable = true; + } + + if (shape == "full") + ct = convn_full; + else if (shape == "same") + ct = convn_same; + else if (shape == "valid") + ct = convn_valid; + else + { + error ("convn: shape type not valid"); + print_usage (); + return retval; + } -template MArray -conv2 (const MArray&, const MArray&, - const MArray&, Shape); + if (args(0).is_single_type () || + args(1).is_single_type ()) + { + if (args(0).is_complex_type () + || args(1).is_complex_type ()) + { + FloatComplexNDArray a (args(0).float_complex_array_value ()); + if (args(1).is_real_type ()) + { + FloatNDArray b (args(1).float_array_value ()); + retval = convn (a, b, ct); + } + else + { + FloatComplexNDArray b (args(1).float_complex_array_value ()); + retval = convn (a, b, ct); + } + } + else + { + FloatNDArray a (args(0).float_array_value ()); + FloatNDArray b (args(1).float_array_value ()); + retval = convn (a, b, ct); + } + } + else + { + if (args(0).is_complex_type () + || args(1).is_complex_type ()) + { + ComplexNDArray a (args(0).complex_array_value ()); + if (args(1).is_real_type ()) + { + NDArray b (args(1).array_value ()); + retval = convn (a, b, ct); + } + else + { + ComplexNDArray b (args(1).complex_array_value ()); + retval = convn (a, b, ct); + } + } + else + { + NDArray a (args(0).array_value ()); + NDArray b (args(1).array_value ()); + retval = convn (a, b, ct); + } + } -template MArray -conv2 (const MArray&, const MArray&, Shape); + return retval; +}