Mercurial > pytave
changeset 46:095e26d93935
support complex numbers
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 26 May 2009 11:41:26 +0200 |
parents | 3eb653452a38 |
children | c83754e57d26 5867e925d0dd |
files | ChangeLog octave_to_python.cc python_to_octave.cc test/test.py |
diffstat | 4 files changed, 80 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/ChangeLog Tue May 26 08:51:39 2009 +0200 +++ b/ChangeLog Tue May 26 11:41:26 2009 +0200 @@ -1,3 +1,15 @@ +2009-05-26 Jaroslav Hajek <highegg@gmail.com> + + * octave_to_python.cc (octvalue_to_pyarrobj): Support Complex + and FloatComplex values. + (octvalue_to_pyobj): Support complex scalars. + * python_to_octave.cc (copy_pyarrobj_to_octarray_dispatch): + New template function. + (matching_type): New helper traits class. + (copy_pyarrobj_to_octarray_dispatch): Support complex types. + (pyarrobj_to_octvalue): Support complex scalars. + * test/test.py: Add tests for complex values. + 2009-05-26 Jaroslav Hajek <highegg@gmail.com> * octave_to_python.cc: New #include (boost/type_traits).
--- a/octave_to_python.cc Tue May 26 08:51:39 2009 +0200 +++ b/octave_to_python.cc Tue May 26 11:41:26 2009 +0200 @@ -1,5 +1,6 @@ /* * Copyright 2008 David Grundberg, Håkan Fors Nilsson + * Copyright 2009 VZLU Prague * * This file is part of Pytave. * @@ -155,13 +156,11 @@ } static PyArrayObject *octvalue_to_pyarrobj(const octave_value &matrix) { - if (matrix.is_complex_type ()) { - throw value_convert_exception( - "Complex Octave matrices conversion not implemented"); - } - if (matrix.is_double_type ()) { - if (matrix.is_real_type()) { + if (matrix.is_complex_type ()) { + return create_array<Complex, ComplexNDArray> + (matrix.complex_array_value(), PyArray_CDOUBLE); + } else if (matrix.is_real_type()) { return create_array<double, NDArray>(matrix.array_value(), PyArray_DOUBLE); } else @@ -170,7 +169,10 @@ #ifdef PYTAVE_USE_OCTAVE_FLOATS if (matrix.is_single_type ()) { - if (matrix.is_real_type()) { + if (matrix.is_complex_type ()) { + return create_array<FloatComplex, FloatComplexNDArray> + (matrix.float_complex_array_value(), PyArray_CFLOAT); + } else if (matrix.is_real_type()) { return create_array<float, FloatNDArray>( matrix.float_array_value(), PyArray_FLOAT); } else @@ -263,6 +265,8 @@ py_object = object(octvalue.bool_value()); else if (octvalue.is_real_scalar()) py_object = object(octvalue.double_value()); + else if (octvalue.is_complex_scalar()) + py_object = object(octvalue.complex_value()); else if (octvalue.is_integer_type()) py_object = object(octvalue.int_value()); else
--- a/python_to_octave.cc Tue May 26 08:51:39 2009 +0200 +++ b/python_to_octave.cc Tue May 26 11:41:26 2009 +0200 @@ -1,5 +1,6 @@ /* * Copyright 2008 David Grundberg, Håkan Fors Nilsson + * Copyright 2009 VZLU Prague * * This file is part of Pytave. * @@ -21,6 +22,7 @@ #include <boost/python.hpp> #include <boost/python/numeric.hpp> #include "arrayobjectdefs.h" +#include <boost/type_traits/integral_constant.hpp> #undef HAVE_STAT /* both boost.python and octave defines HAVE_STAT... */ #include <octave/oct.h> #include <octave/oct-map.h> @@ -67,13 +69,43 @@ } } + template <class PythonPrimitive, class OctaveBase> + static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix, + const PyArrayObject* const pyarr, + const boost::true_type&) { + copy_pyarrobj_to_octarray<PythonPrimitive, OctaveBase> + (matrix, pyarr, 0, 1, 0, 0); + } + + template <class PythonPrimitive, class OctaveBase> + static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix, + const PyArrayObject* const pyarr, + const boost::false_type&) { + assert(0); + } + + template <class X, class Y> class matching_type : public boost::false_type { }; + template <class X> class matching_type<X, X> : public boost::true_type { }; + template <class X> class matching_type<X, octave_int<X> > : public boost::true_type { }; +#ifndef PYTAVE_USE_OCTAVE_FLOATS + template <> class matching_type<float, double> : public boost::true_type { }; + template <> class matching_type<FloatComplex, Complex> : public boost::true_type { }; +#endif + + template <class PythonPrimitive, class OctaveBase> + static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix, + const PyArrayObject* const pyarr) { + matching_type<PythonPrimitive, typename OctaveBase::element_type> inst; + copy_pyarrobj_to_octarray_dispatch<PythonPrimitive, OctaveBase> (matrix, pyarr, inst); + } + template <class OctaveBase> static void copy_pyarrobj_to_octarray_boot(OctaveBase &matrix, const PyArrayObject* const pyarr) { #define ARRAYCASE(AC_pyarrtype, AC_primitive) case AC_pyarrtype: \ - copy_pyarrobj_to_octarray<AC_primitive, OctaveBase>\ - (matrix, pyarr, 0, 1, 0, 0); \ + copy_pyarrobj_to_octarray_dispatch<AC_primitive, OctaveBase>\ + (matrix, pyarr); \ break; \ switch (pyarr->descr->type_num) { @@ -91,8 +123,12 @@ /* Commonly Numeric.array(..., Numeric.Float) */ ARRAYCASE(PyArray_DOUBLE, double) -// ARRAYCASE(PyArray_CFLOAT, ) -// ARRAYCASE(PyArray_CDOUBLE, ) + + /* Commonly Numeric.array(..., Numeric.Complex32) */ + ARRAYCASE(PyArray_CFLOAT, FloatComplex) + + /* Commonly Numeric.array(..., Numeric.Complex) */ + ARRAYCASE(PyArray_CDOUBLE, Complex) // ARRAYCASE(PyArray_OBJECT, ) default: throw object_convert_exception( @@ -177,6 +213,16 @@ case PyArray_DOUBLE: pyarrobj_to_octvalueNd<NDArray>(octvalue, pyarr, dims); break; + case PyArray_CFLOAT: +#ifdef PYTAVE_USE_OCTAVE_FLOATS + pyarrobj_to_octvalueNd<FloatComplexNDArray>(octvalue, pyarr, dims); + break; +#else + /* fallthrough */ +#endif + case PyArray_CDOUBLE: + pyarrobj_to_octvalueNd<ComplexNDArray>(octvalue, pyarr, dims); + break; default: throw object_convert_exception( PyEval_GetFuncDesc((PyObject*)(pyarr)) + string(" ") @@ -304,6 +350,7 @@ const boost::python::object &py_object) { extract<int> intx(py_object); extract<double> doublex(py_object); + extract<Complex> complexx(py_object); extract<string> stringx(py_object); extract<numeric::array> arrayx(py_object); extract<boost::python::list> listx(py_object); @@ -312,6 +359,8 @@ oct_value = intx(); } else if (doublex.check()) { oct_value = doublex(); + } else if (complexx.check()) { + oct_value = complexx(); } else if (arrayx.check()) { pyarr_to_octvalue(oct_value, (PyArrayObject*)py_object.ptr()); } else if (stringx.check()) {
--- a/test/test.py Tue May 26 08:51:39 2009 +0200 +++ b/test/test.py Tue May 26 11:41:26 2009 +0200 @@ -20,6 +20,8 @@ arr2f = Numeric.array([[1.32, 2, 3, 4],[5,6,7,8]], Numeric.Float32) arr2d = Numeric.array([[1.17, 2, 3, 4],[5,6,7,8]], Numeric.Float) arr3f = Numeric.array([[[1.32, 2, 3, 4],[5,6,7,8]],[[9, 10, 11, 12],[13,14,15,16]]], Numeric.Float32) +arr1c = Numeric.array([[1+2j, 3+4j, 5+6j, 7+0.5j]], Numeric.Complex) +arr1fc = Numeric.array([[1+2j, 3+4j, 5+6j, 7+0.5j]], Numeric.Complex32) alimit_int32 = Numeric.array([[-2147483648, 2147483647]], Numeric.Int32); alimit_int16 = Numeric.array([[-32768, 32767, -32769, 32768]], Numeric.Int16); @@ -182,7 +184,8 @@ testmatrix(arr1fT2) testmatrix(arr1i) testmatrix(arr1b) -testmatrix(arr1i32) +testmatrix(arr1c) +testmatrix(arr1fc) # 2d arrays testmatrix(arr2f)