changeset 10384:978f5c94b11f

initial implementation of conv2/convn in liboctave
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 02 Mar 2010 10:04:12 +0100
parents 796258b94b22
children 56116dceb1e0
files liboctave/ChangeLog liboctave/Makefile.am src/ChangeLog src/DLD-FUNCTIONS/conv2.cc
diffstat 4 files changed, 170 insertions(+), 276 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Tue Mar 02 01:14:11 2010 +0100
+++ b/liboctave/ChangeLog	Tue Mar 02 10:04:12 2010 +0100
@@ -1,3 +1,8 @@
+2010-03-02  Jaroslav Hajek  <highegg@gmail.com>
+
+	* oct-convn.h, oct-convn.cc: New sources.
+	* Makefile.am: Include them.
+
 2010-03-01  David Bateman  <dbateman@free.fr>
 
 	* Sparse.cc (Sparse<T>::maybe_delete_elements (idxx_vector&)):
--- a/liboctave/Makefile.am	Tue Mar 02 01:14:11 2010 +0100
+++ b/liboctave/Makefile.am	Tue Mar 02 10:04:12 2010 +0100
@@ -211,6 +211,7 @@
   mach-info.h \
   oct-alloc.h \
   oct-cmplx.h \
+  oct-convn.h \
   oct-env.h \
   oct-fftw.h \
   oct-group.h \
@@ -425,6 +426,7 @@
   lo-utils.cc \
   mach-info.cc \
   oct-alloc.cc \
+  oct-convn.cc \
   oct-env.cc \
   oct-fftw.cc \
   oct-glob.cc \
--- a/src/ChangeLog	Tue Mar 02 01:14:11 2010 +0100
+++ b/src/ChangeLog	Tue Mar 02 10:04:12 2010 +0100
@@ -1,3 +1,8 @@
+2010-03-02  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/conv.cc (Fconv2): Rewrite using convn from liboctave.
+	(Fconvn): New DEFUN.
+
 2010-03-01  John W. Eaton  <jwe@octave.org>
 
 	* (str2double1): Pass argument as const reference and make
--- a/src/DLD-FUNCTIONS/conv2.cc	Tue Mar 02 01:14:11 2010 +0100
+++ b/src/DLD-FUNCTIONS/conv2.cc	Tue Mar 02 10:04:12 2010 +0100
@@ -25,6 +25,8 @@
 #include <config.h>
 #endif
 
+#include "oct-convn.h"
+
 #include "defun-dld.h"
 #include "error.h"
 #include "oct-obj.h"
@@ -32,235 +34,6 @@
 
 enum Shape { SHAPE_FULL, SHAPE_SAME, SHAPE_VALID };
 
-#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
-extern MArray<double>
-conv2 (const MArray<double>&, const MArray<double>&, const MArray<double>&,
-       Shape);
-
-extern MArray<Complex>
-conv2 (const MArray<Complex>&, const MArray<Complex>&,
-       const MArray<Complex>&, Shape);
-
-extern MArray<float>
-conv2 (const MArray<float>&, const MArray<float>&, const MArray<float>&,
-       Shape);
-
-extern MArray<FloatComplex>
-conv2 (const MArray<FloatComplex>&, const MArray<FloatComplex>&,
-       const MArray<FloatComplex>&, Shape);
-#endif
-
-template <class T>
-MArray<T>
-conv2 (const MArray<T>& R, const MArray<T>& C, const MArray<T>& A, Shape ishape)
-{
-  octave_idx_type  Rn = R.length ();
-  octave_idx_type  Cm = C.length ();
-  octave_idx_type  Am = A.rows ();
-  octave_idx_type  An = A.columns ();
-
-  // Calculate the size of the output matrix:
-  // in order to stay Matlab compatible, it is based
-  // on the third parameter if it's separable, and the
-  // first if it's not
-
-  octave_idx_type outM = 0;
-  octave_idx_type outN = 0;
-  octave_idx_type edgM = 0;
-  octave_idx_type edgN = 0;
-
-  switch (ishape)
-    {
-      case SHAPE_FULL:
-        outM = Am + Cm - 1;
-        outN = An + Rn - 1;
-        edgM = Cm - 1;
-        edgN = Rn - 1;
-        break;
-
-      case SHAPE_SAME:
-        outM = Am;
-        outN = An;
-        // Follow the Matlab convention (ie + instead of -)
-        edgM = (Cm - 1) /2;
-        edgN = (Rn - 1) /2;
-        break;
-
-      case SHAPE_VALID:
-        outM = Am - Cm + 1;
-        outN = An - Rn + 1;
-        if (outM < 0)
-          outM = 0;
-        if (outN < 0)
-          outN = 0;
-        edgM = edgN = 0;
-        break;
-
-      default:
-        error ("conv2: invalid value of parameter ishape");
-    }
-
-  MArray<T> O (outM, outN);
-
-  T *Od0 = O.fortran_vec ();
-
-  // X accumulates the 1-D conv for each row, before calculating
-  // the convolution in the other direction.  There is no efficiency
-  // advantage to doing it in either direction first.
-
-  MArray<T> X (An, 1);
-
-  T *Xd0 = X.fortran_vec ();
-
-  for (octave_idx_type oi = 0; oi < outM; oi++)
-    {
-      T *Xd = Xd0;
-
-      octave_idx_type ci0 = Cm - 1 - std::max (0, edgM-oi);
-      octave_idx_type ai0 = std::max (0, oi-edgM);
-
-      const T *Cd0 = C.data () + ci0;
-      const T *Ad0 = A.data () + ai0;
-
-      for (octave_idx_type oj = 0; oj < An; oj++)
-        {
-          T sum = 0;
-
-          const T* Cd = Cd0;
-          const T* Ad = Ad0 + Am*oj;
-
-          for (octave_idx_type ci = ci0, ai = ai0; ci >= 0 && ai < Am;
-               ci--, ai++)
-            sum += (*Ad++) * (*Cd--);
-
-          *Xd++ = sum;
-        }
-
-      T *Od = Od0 + oi;
-
-      const T *Rd0 = R.data ();
-
-      for (octave_idx_type oj = 0; oj < outN; oj++)
-        {
-          T sum = 0;
-
-          octave_idx_type rj = Rn - 1 - std::max (0, edgN-oj);
-          octave_idx_type aj = std::max (0, oj-edgN);
-
-          Xd = Xd0 + aj;
-          const T* Rd = Rd0 + rj;
-
-          for ( ; rj >= 0 && aj < An; rj--, aj++)
-            sum += (*Xd++) * (*Rd--);
-
-          *Od = sum;
-          Od += outM;
-        }
-    }
-
-  return O;
-}
-
-#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
-extern MArray<double>
-conv2 (MArray<double>&, MArray<double>&, Shape);
-
-extern MArray<Complex>
-conv2 (MArray<Complex>&, MArray<Complex>&, Shape);
-
-extern MArray<float>
-conv2 (MArray<float>&, MArray<float>&, Shape);
-
-extern MArray<FloatComplex>
-conv2 (MArray<FloatComplex>&, MArray<FloatComplex>&, Shape);
-#endif
-
-template <class T>
-MArray<T>
-conv2 (const MArray<T>& A, const MArray<T>& B, Shape ishape)
-{
-  // Convolution works fastest if we choose the A matrix to be
-  // the largest.
-
-  // Here we calculate the size of the output matrix,
-  // in order to stay Matlab compatible, it is based
-  // on the third parameter if it's separable, and the
-  // first if it's not
-
-  // NOTE in order to be Matlab compatible, we give argueably
-  // wrong sizes for 'valid' if the smallest matrix is first
-
-  octave_idx_type Am = A.rows ();
-  octave_idx_type An = A.columns ();
-  octave_idx_type Bm = B.rows ();
-  octave_idx_type Bn = B.columns ();
-
-  octave_idx_type outM = 0;
-  octave_idx_type outN = 0;
-  octave_idx_type edgM = 0;
-  octave_idx_type edgN = 0;
-
-  switch (ishape)
-    {
-      case SHAPE_FULL:
-        outM = Am + Bm - 1;
-        outN = An + Bn - 1;
-        edgM = Bm - 1;
-        edgN = Bn - 1;
-        break;
-
-      case SHAPE_SAME:
-        outM = Am;
-        outN = An;
-        edgM = (Bm - 1) /2;
-        edgN = (Bn - 1) /2;
-        break;
-
-      case SHAPE_VALID:
-        outM = Am - Bm + 1;
-        outN = An - Bn + 1;
-        if (outM < 0)
-          outM = 0;
-        if (outN < 0)
-          outN = 0;
-        edgM = edgN = 0;
-        break;
-    }
-
-  MArray<T> O (outM, outN);
-
-  T *Od = O.fortran_vec ();
-
-  for (octave_idx_type oj = 0; oj < outN; oj++)
-    {
-      octave_idx_type aj0 = std::max (0, oj-edgN);
-      octave_idx_type bj0 = Bn - 1 - std::max (0, edgN-oj);
-
-      for (octave_idx_type oi = 0; oi < outM; oi++)
-        {
-          T sum = 0;
-
-          octave_idx_type bi0 = Bm - 1 - std::max (0, edgM-oi);
-          octave_idx_type ai0 = std::max (0, oi-edgM);
-
-          for (octave_idx_type aj = aj0, bj = bj0; bj >= 0 && aj < An;
-               bj--, aj++)
-            {
-              const T* Ad = A.data () + ai0 + Am*aj;
-              const T* Bd = B.data () + bi0 + Bm*bj;
-
-              for (octave_idx_type ai = ai0, bi = bi0; bi >= 0 && ai < Am;
-                   bi--, ai++)
-                sum += (*Ad++) * (*Bd--);
-            }
-
-          *Od++ = sum;
-        }
-    }
-
-  return O;
-}
-
 /*
 %!test
 %! b = [0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18];
@@ -270,7 +43,7 @@
 %! b = single([0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18]);
 %! assert(conv2(single([0,1;1,2]),single([1,2,3;4,5,6;7,8,9])),b);
 
-%!assert (conv2 (1:2, 1:3, [1,2;3,4;5,6]),
+%!assert (conv2 (1:3, 1:2, [1,2;3,4;5,6]),
 %!        [1,4,4;5,18,16;14,48,40;19,62,48;15,48,36;]);
 */
 
@@ -302,7 +75,7 @@
   int nargin = args.length ();
   std::string shape = "full"; //default
   bool separable = false;
-  Shape ishape;
+  convn_type ct;
 
   if (nargin < 2)
     {
@@ -323,11 +96,11 @@
     }
 
   if (shape == "full")
-    ishape = SHAPE_FULL;
+    ct = convn_full;
   else if (shape == "same")
-    ishape = SHAPE_SAME;
+    ct = convn_same;
   else if (shape == "valid")
-    ishape = SHAPE_VALID;
+    ct = convn_valid;
   else
     {
       error ("conv2: shape type not valid");
@@ -354,21 +127,26 @@
                || args(1).is_complex_type ()
                || args(2).is_complex_type ())
              {
-               FloatComplexColumnVector v1 (args(0).float_complex_vector_value ());
-               FloatComplexColumnVector v2 (args(1).float_complex_vector_value ());
                FloatComplexMatrix a (args(2).float_complex_matrix_value ());
-               FloatComplexMatrix c (conv2 (v1, v2, a, ishape));
-               if (! error_state)
-                 retval = c;
+               if (args(1).is_real_type () && args(2).is_real_type ())
+                 {
+                   FloatColumnVector v1 (args(0).float_vector_value ());
+                   FloatRowVector v2 (args(1).float_vector_value ());
+                   retval = convn (a, v1, v2, ct);
+                 }
+               else
+                 {
+                   FloatComplexColumnVector v1 (args(0).float_complex_vector_value ());
+                   FloatComplexRowVector v2 (args(1).float_complex_vector_value ());
+                   retval = convn (a, v1, v2, ct);
+                 }
              }
            else
              {
                FloatColumnVector v1 (args(0).float_vector_value ());
-               FloatColumnVector v2 (args(1).float_vector_value ());
+               FloatRowVector v2 (args(1).float_vector_value ());
                FloatMatrix a (args(2).float_matrix_value ());
-               FloatMatrix c (conv2 (v1, v2, a, ishape));
-               if (! error_state)
-                 retval = c;
+               retval = convn (a, v1, v2, ct);
              }
          }
        else
@@ -377,21 +155,26 @@
                || args(1).is_complex_type ()
                || args(2).is_complex_type ())
              {
-               ComplexColumnVector v1 (args(0).complex_vector_value ());
-               ComplexColumnVector v2 (args(1).complex_vector_value ());
                ComplexMatrix a (args(2).complex_matrix_value ());
-               ComplexMatrix c (conv2 (v1, v2, a, ishape));
-               if (! error_state)
-                 retval = c;
+               if (args(1).is_real_type () && args(2).is_real_type ())
+                 {
+                   ColumnVector v1 (args(0).vector_value ());
+                   RowVector v2 (args(1).vector_value ());
+                   retval = convn (a, v1, v2, ct);
+                 }
+               else
+                 {
+                   ComplexColumnVector v1 (args(0).complex_vector_value ());
+                   ComplexRowVector v2 (args(1).complex_vector_value ());
+                   retval = convn (a, v1, v2, ct);
+                 }
              }
            else
              {
                ColumnVector v1 (args(0).vector_value ());
-               ColumnVector v2 (args(1).vector_value ());
+               RowVector v2 (args(1).vector_value ());
                Matrix a (args(2).matrix_value ());
-               Matrix c (conv2 (v1, v2, a, ishape));
-               if (! error_state)
-                 retval = c;
+               retval = convn (a, v1, v2, ct);
              }
          }
      } // if (separable)
@@ -404,18 +187,22 @@
                || args(1).is_complex_type ())
              {
                FloatComplexMatrix a (args(0).float_complex_matrix_value ());
-               FloatComplexMatrix b (args(1).float_complex_matrix_value ());
-               FloatComplexMatrix c (conv2 (a, b, ishape));
-               if (! error_state)
-                 retval = c;
+               if (args(1).is_real_type ())
+                 {
+                   FloatMatrix b (args(1).float_matrix_value ());
+                   retval = convn (a, b, ct);
+                 }
+               else
+                 {
+                   FloatComplexMatrix b (args(1).float_complex_matrix_value ());
+                   retval = convn (a, b, ct);
+                 }
              }
            else
              {
                FloatMatrix a (args(0).float_matrix_value ());
                FloatMatrix b (args(1).float_matrix_value ());
-               FloatMatrix c (conv2 (a, b, ishape));
-               if (! error_state)
-                 retval = c;
+               retval = convn (a, b, ct);
              }
          }
        else
@@ -424,18 +211,22 @@
                || args(1).is_complex_type ())
              {
                ComplexMatrix a (args(0).complex_matrix_value ());
-               ComplexMatrix b (args(1).complex_matrix_value ());
-               ComplexMatrix c (conv2 (a, b, ishape));
-               if (! error_state)
-                 retval = c;
+               if (args(1).is_real_type ())
+                 {
+                   Matrix b (args(1).matrix_value ());
+                   retval = convn (a, b, ct);
+                 }
+               else
+                 {
+                   ComplexMatrix b (args(1).complex_matrix_value ());
+                   retval = convn (a, b, ct);
+                 }
              }
            else
              {
                Matrix a (args(0).matrix_value ());
                Matrix b (args(1).matrix_value ());
-               Matrix c (conv2 (a, b, ishape));
-               if (! error_state)
-                 retval = c;
+               retval = convn (a, b, ct);
              }
          }
 
@@ -444,16 +235,107 @@
    return retval;
 }
 
-template MArray<double>
-conv2 (const MArray<double>&, const MArray<double>&, const MArray<double>&,
-       Shape);
+DEFUN_DLD (convn, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Loadable Function} {y =} conv2 (@var{a}, @var{b}, @var{shape})\n\
+\n\
+Returns n-D convolution of @var{a} and @var{b} where the size\n\
+of @var{c} is given by\n\
+\n\
+@table @asis\n\
+@item @var{shape} = 'full'\n\
+returns full n-D convolution\n\
+@item @var{shape} = 'same'\n\
+same size as a. 'central' part of convolution\n\
+@item @var{shape} = 'valid'\n\
+only parts which do not include zero-padded edges\n\
+@end table\n\
+\n\
+By default @var{shape} is 'full'.\n\
+@end deftypefn")
+{
+  octave_value retval;
+  octave_value tmp;
+  int nargin = args.length ();
+  std::string shape = "full"; //default
+  bool separable = false;
+  convn_type ct;
 
-template MArray<double>
-conv2 (const MArray<double>&, const MArray<double>&, Shape);
+  if (nargin < 2 || nargin > 3)
+    {
+     print_usage ();
+     return retval;
+    }
+  else if (nargin == 3)
+    {
+      if (args(2).is_string ())
+        shape = args(2).string_value ();
+      else
+        separable = true;
+    }
+
+  if (shape == "full")
+    ct = convn_full;
+  else if (shape == "same")
+    ct = convn_same;
+  else if (shape == "valid")
+    ct = convn_valid;
+  else
+    {
+      error ("convn: shape type not valid");
+      print_usage ();
+      return retval;
+    }
 
-template MArray<Complex>
-conv2 (const MArray<Complex>&, const MArray<Complex>&,
-       const MArray<Complex>&, Shape);
+  if (args(0).is_single_type () || 
+      args(1).is_single_type ())
+    {
+      if (args(0).is_complex_type ()
+          || args(1).is_complex_type ())
+        {
+          FloatComplexNDArray a (args(0).float_complex_array_value ());
+          if (args(1).is_real_type ())
+            {
+              FloatNDArray b (args(1).float_array_value ());
+              retval = convn (a, b, ct);
+            }
+          else
+            {
+              FloatComplexNDArray b (args(1).float_complex_array_value ());
+              retval = convn (a, b, ct);
+            }
+        }
+      else
+        {
+          FloatNDArray a (args(0).float_array_value ());
+          FloatNDArray b (args(1).float_array_value ());
+          retval = convn (a, b, ct);
+        }
+    }
+  else
+    {
+      if (args(0).is_complex_type ()
+          || args(1).is_complex_type ())
+        {
+          ComplexNDArray a (args(0).complex_array_value ());
+          if (args(1).is_real_type ())
+            {
+              NDArray b (args(1).array_value ());
+              retval = convn (a, b, ct);
+            }
+          else
+            {
+              ComplexNDArray b (args(1).complex_array_value ());
+              retval = convn (a, b, ct);
+            }
+        }
+      else
+        {
+          NDArray a (args(0).array_value ());
+          NDArray b (args(1).array_value ());
+          retval = convn (a, b, ct);
+        }
+    }
 
-template MArray<Complex>
-conv2 (const MArray<Complex>&, const MArray<Complex>&, Shape);
+   return retval;
+}