changeset 7660:5f6e11567f70

Allow convolving real data with complex data
author sh@sh-laptop
date Thu, 27 Mar 2008 16:15:36 -0400
parents 4ab2488ab2b4
children f3493c40a0bd
files src/ChangeLog src/DLD-FUNCTIONS/__convn__.cc
diffstat 2 files changed, 79 insertions(+), 20 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Thu Mar 27 15:17:53 2008 -0400
+++ b/src/ChangeLog	Thu Mar 27 16:15:36 2008 -0400
@@ -1,3 +1,13 @@
+2008-03-27  John W. Eaton  <jwe@octave.org>
+
+	* DLD-FUNCTIONS/__convn__.cc (convn): Use traits class and
+	typedefs to allow all types to be deduced from argument types.
+
+2008-03-27  Soren Hauberg  <hauberg@gmail.com>
+
+	* DLD-FUNCTIONS/__convn__.cc (Fconvn): Allow convolving real data with
+	complex data.
+
 2008-03-26  John W. Eaton  <jwe@octave.org>
 
 	* ov-range.h (octave_range::subsref (const std::string&,
--- a/src/DLD-FUNCTIONS/__convn__.cc	Thu Mar 27 15:17:53 2008 -0400
+++ b/src/DLD-FUNCTIONS/__convn__.cc	Thu Mar 27 16:15:36 2008 -0400
@@ -31,10 +31,32 @@
 
 #include "defun-dld.h"
 
+template <class T1, class T2>
+class
+octave_convn_traits
+{
+public:
+  // The return type for a T1 by T2 convn operation.
+  typedef T1 TR;
+};
+
+#define OCTAVE_CONVN_TRAIT(T1, T2, T3) \
+  template<> \
+  class octave_convn_traits <T1, T2> \
+  { \
+  public: \
+    typedef T3 TR; \
+  }
+
+OCTAVE_CONVN_TRAIT (NDArray, NDArray, NDArray);
+OCTAVE_CONVN_TRAIT (ComplexNDArray, NDArray, ComplexNDArray);
+OCTAVE_CONVN_TRAIT (NDArray, ComplexNDArray, ComplexNDArray);
+OCTAVE_CONVN_TRAIT (ComplexNDArray, ComplexNDArray, ComplexNDArray);
+
 // FIXME -- this function should maybe be available in liboctave?
-template <class MT, class ST> 
+template <class MTa, class MTb> 
 octave_value
-convn (const MT& a, const MT& b)
+convn (const MTa& a, const MTb& b)
 {
   octave_value retval;
 
@@ -56,7 +78,9 @@
   for (octave_idx_type n = 0; n < ndims; n++)
     out_size(n) = std::max (a_size(n) - b_size(n) + 1, 0);
 
-  MT out = MT (out_size);
+  typedef typename octave_convn_traits<MTa, MTb>::TR MTout;
+
+  MTout out (out_size);
 
   const octave_idx_type out_numel = out.numel ();
   
@@ -72,7 +96,7 @@
       OCTAVE_QUIT;
 
       // For each neighbour
-      ST sum = 0;
+      typename MTout::element_type sum = 0;
 
       for (octave_idx_type n = 0; n < ndims; n++)
         b_idx(n) = 0;
@@ -108,24 +132,49 @@
 
   if (args.length () == 2)
     {
-      if (args(0).is_real_type () && args(1).is_real_type ())
-        {
-          const NDArray a = args (0).array_value ();
-          const NDArray b = args (1).array_value ();
+      if (args(0).is_real_type ())
+	{
+	  if (args(1).is_real_type ())
+	    {
+	      const NDArray a = args (0).array_value ();
+	      const NDArray b = args (1).array_value ();
+
+	      if (! error_state)
+		retval = convn (a, b);
+	    }
+	  else if (args(1).is_complex_type ())
+	    {
+	      const NDArray a = args (0).array_value ();
+	      const ComplexNDArray b = args (1).complex_array_value ();
 
-	  if (! error_state)
-	    retval = convn<NDArray, double> (a, b);
-        }
-      else if (args(0).is_complex_type () && args(1).is_complex_type ())
-        {
-          const ComplexNDArray a = args (0).complex_array_value ();
-          const ComplexNDArray b = args (1).complex_array_value ();
+	      if (! error_state)
+		retval = convn (a, b);
+	    }
+	  else
+	    error ("__convn__: invalid call");
+	}
+      else if (args(0).is_complex_type ())
+	{
+	  if (args(1).is_complex_type ())
+	    {
+	      const ComplexNDArray a = args (0).complex_array_value ();
+	      const ComplexNDArray b = args (1).complex_array_value ();
 
-	  if (! error_state)
-	    retval = convn<ComplexNDArray, Complex> (a, b);
-        }
-      else
-	error ("__convn__: first and second input should be real, or complex arrays");
+	      if (! error_state)
+		retval = convn (a, b);
+	    }
+	  else if (args(1).is_real_type ())
+	    {
+	      const ComplexNDArray a = args (0).complex_array_value ();
+	      const NDArray b = args (1).array_value ();
+
+	      if (! error_state)
+		retval = convn (a, b);
+	    }
+	  else
+	    error ("__convn__: invalid call");
+	}
+      error ("__convn__: invalid call");
     }
   else
     print_usage ();