changeset 46:095e26d93935

support complex numbers
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 26 May 2009 11:41:26 +0200
parents 3eb653452a38
children c83754e57d26 5867e925d0dd
files ChangeLog octave_to_python.cc python_to_octave.cc test/test.py
diffstat 4 files changed, 80 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/ChangeLog	Tue May 26 08:51:39 2009 +0200
+++ b/ChangeLog	Tue May 26 11:41:26 2009 +0200
@@ -1,3 +1,15 @@
+2009-05-26  Jaroslav Hajek  <highegg@gmail.com>
+
+	* octave_to_python.cc (octvalue_to_pyarrobj): Support Complex
+	and FloatComplex values.
+	(octvalue_to_pyobj): Support complex scalars.
+	* python_to_octave.cc (copy_pyarrobj_to_octarray_dispatch):
+	New template function.
+	(matching_type): New helper traits class.
+	(copy_pyarrobj_to_octarray_dispatch): Support complex types.
+	(pyarrobj_to_octvalue): Support complex scalars.
+	* test/test.py: Add tests for complex values.
+
 2009-05-26  Jaroslav Hajek  <highegg@gmail.com>
 
 	* octave_to_python.cc: New #include (boost/type_traits).
--- a/octave_to_python.cc	Tue May 26 08:51:39 2009 +0200
+++ b/octave_to_python.cc	Tue May 26 11:41:26 2009 +0200
@@ -1,5 +1,6 @@
 /*
  *  Copyright 2008 David Grundberg, Håkan Fors Nilsson
+ *  Copyright 2009 VZLU Prague
  *
  *  This file is part of Pytave.
  *
@@ -155,13 +156,11 @@
    }
 
    static PyArrayObject *octvalue_to_pyarrobj(const octave_value &matrix) {
-      if (matrix.is_complex_type ()) {
-            throw value_convert_exception(
-               "Complex Octave matrices conversion not implemented");
-      }
-
       if (matrix.is_double_type ()) {
-         if (matrix.is_real_type()) {
+         if (matrix.is_complex_type ()) {
+            return create_array<Complex, ComplexNDArray>
+               (matrix.complex_array_value(), PyArray_CDOUBLE);
+         } else if (matrix.is_real_type()) {
             return create_array<double, NDArray>(matrix.array_value(),
                                                  PyArray_DOUBLE);
          } else
@@ -170,7 +169,10 @@
 
 #ifdef PYTAVE_USE_OCTAVE_FLOATS
       if (matrix.is_single_type ()) {
-         if (matrix.is_real_type()) {
+         if (matrix.is_complex_type ()) {
+            return create_array<FloatComplex, FloatComplexNDArray>
+               (matrix.float_complex_array_value(), PyArray_CFLOAT);
+         } else if (matrix.is_real_type()) {
             return create_array<float, FloatNDArray>(
                matrix.float_array_value(), PyArray_FLOAT);
          } else
@@ -263,6 +265,8 @@
             py_object = object(octvalue.bool_value());
          else if (octvalue.is_real_scalar())
             py_object = object(octvalue.double_value());
+         else if (octvalue.is_complex_scalar())
+            py_object = object(octvalue.complex_value());
          else if (octvalue.is_integer_type())
             py_object = object(octvalue.int_value());
          else
--- a/python_to_octave.cc	Tue May 26 08:51:39 2009 +0200
+++ b/python_to_octave.cc	Tue May 26 11:41:26 2009 +0200
@@ -1,5 +1,6 @@
 /*
  *  Copyright 2008 David Grundberg, Håkan Fors Nilsson
+ *  Copyright 2009 VZLU Prague
  *
  *  This file is part of Pytave.
  *
@@ -21,6 +22,7 @@
 #include <boost/python.hpp>
 #include <boost/python/numeric.hpp>
 #include "arrayobjectdefs.h"
+#include <boost/type_traits/integral_constant.hpp>
 #undef HAVE_STAT /* both boost.python and octave defines HAVE_STAT... */
 #include <octave/oct.h>
 #include <octave/oct-map.h>
@@ -67,13 +69,43 @@
       }
    }
 
+   template <class PythonPrimitive, class OctaveBase>
+   static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix,
+                                       const PyArrayObject* const pyarr,
+                                       const boost::true_type&) {
+      copy_pyarrobj_to_octarray<PythonPrimitive, OctaveBase>
+         (matrix, pyarr, 0, 1, 0, 0);
+   }
+
+   template <class PythonPrimitive, class OctaveBase>
+   static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix,
+                                       const PyArrayObject* const pyarr,
+                                       const boost::false_type&) {
+      assert(0);
+   }
+
+   template <class X, class Y> class matching_type : public boost::false_type { };
+   template <class X> class matching_type<X, X> : public boost::true_type { };
+   template <class X> class matching_type<X, octave_int<X> > : public boost::true_type { };
+#ifndef PYTAVE_USE_OCTAVE_FLOATS
+   template <> class matching_type<float, double> : public boost::true_type { };
+   template <> class matching_type<FloatComplex, Complex> : public boost::true_type { };
+#endif
+
+   template <class PythonPrimitive, class OctaveBase>
+   static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix,
+                                       const PyArrayObject* const pyarr) {
+      matching_type<PythonPrimitive, typename OctaveBase::element_type> inst;
+      copy_pyarrobj_to_octarray_dispatch<PythonPrimitive, OctaveBase> (matrix, pyarr, inst);
+   }
+
    template <class OctaveBase>
    static void copy_pyarrobj_to_octarray_boot(OctaveBase &matrix,
                                        const PyArrayObject* const pyarr) {
 
 #define ARRAYCASE(AC_pyarrtype, AC_primitive) case AC_pyarrtype: \
-         copy_pyarrobj_to_octarray<AC_primitive, OctaveBase>\
-         (matrix, pyarr, 0, 1, 0, 0); \
+         copy_pyarrobj_to_octarray_dispatch<AC_primitive, OctaveBase>\
+         (matrix, pyarr); \
          break; \
 
       switch (pyarr->descr->type_num) {
@@ -91,8 +123,12 @@
 
          /* Commonly Numeric.array(..., Numeric.Float) */
          ARRAYCASE(PyArray_DOUBLE, double)
-//         ARRAYCASE(PyArray_CFLOAT, )
-//         ARRAYCASE(PyArray_CDOUBLE, )
+
+         /* Commonly Numeric.array(..., Numeric.Complex32) */
+         ARRAYCASE(PyArray_CFLOAT, FloatComplex)
+
+         /* Commonly Numeric.array(..., Numeric.Complex) */
+         ARRAYCASE(PyArray_CDOUBLE, Complex)
 //         ARRAYCASE(PyArray_OBJECT, )
          default:
             throw object_convert_exception(
@@ -177,6 +213,16 @@
          case PyArray_DOUBLE:
             pyarrobj_to_octvalueNd<NDArray>(octvalue, pyarr, dims);
             break;
+         case PyArray_CFLOAT:
+#ifdef PYTAVE_USE_OCTAVE_FLOATS
+            pyarrobj_to_octvalueNd<FloatComplexNDArray>(octvalue, pyarr, dims);
+            break;
+#else
+            /* fallthrough */
+#endif
+         case PyArray_CDOUBLE:
+            pyarrobj_to_octvalueNd<ComplexNDArray>(octvalue, pyarr, dims);
+            break;
          default:
             throw object_convert_exception(
                PyEval_GetFuncDesc((PyObject*)(pyarr)) + string(" ")
@@ -304,6 +350,7 @@
                           const boost::python::object &py_object) {
       extract<int> intx(py_object);
       extract<double> doublex(py_object);
+      extract<Complex> complexx(py_object);
       extract<string> stringx(py_object);
       extract<numeric::array> arrayx(py_object);
       extract<boost::python::list> listx(py_object);
@@ -312,6 +359,8 @@
          oct_value = intx();
       } else if (doublex.check()) {
          oct_value = doublex();
+      } else if (complexx.check()) {
+         oct_value = complexx();
       } else if (arrayx.check()) {
          pyarr_to_octvalue(oct_value, (PyArrayObject*)py_object.ptr());
       } else if (stringx.check()) {
--- a/test/test.py	Tue May 26 08:51:39 2009 +0200
+++ b/test/test.py	Tue May 26 11:41:26 2009 +0200
@@ -20,6 +20,8 @@
 arr2f = Numeric.array([[1.32, 2, 3, 4],[5,6,7,8]], Numeric.Float32)
 arr2d = Numeric.array([[1.17, 2, 3, 4],[5,6,7,8]], Numeric.Float)
 arr3f = Numeric.array([[[1.32, 2, 3, 4],[5,6,7,8]],[[9, 10, 11, 12],[13,14,15,16]]], Numeric.Float32)
+arr1c = Numeric.array([[1+2j, 3+4j, 5+6j, 7+0.5j]], Numeric.Complex)
+arr1fc = Numeric.array([[1+2j, 3+4j, 5+6j, 7+0.5j]], Numeric.Complex32)
 
 alimit_int32 = Numeric.array([[-2147483648, 2147483647]], Numeric.Int32);
 alimit_int16 = Numeric.array([[-32768, 32767, -32769, 32768]], Numeric.Int16);
@@ -182,7 +184,8 @@
 testmatrix(arr1fT2)
 testmatrix(arr1i)
 testmatrix(arr1b)
-testmatrix(arr1i32)
+testmatrix(arr1c)
+testmatrix(arr1fc)
 
 # 2d arrays
 testmatrix(arr2f)