Mercurial > pytave
view python_to_octave.cc @ 79:d60165bfc849
Support 0D Numeric arrays
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Fri, 18 Sep 2009 10:31:15 +0200 |
parents | b0991511a16d |
children | 2e8b52a5e1b1 |
line wrap: on
line source
/* * Copyright 2008 David Grundberg, HÃ¥kan Fors Nilsson * Copyright 2009 VZLU Prague * * This file is part of Pytave. * * Pytave is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * Pytave is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with Pytave. If not, see <http://www.gnu.org/licenses/>. */ #include <iostream> #include <boost/python.hpp> #include <boost/python/numeric.hpp> #include <boost/type_traits/integral_constant.hpp> #undef HAVE_STAT /* both boost.python and octave defines HAVE_STAT... */ #include <octave/oct.h> #include <octave/oct-map.h> #include <octave/Cell.h> #include <octave/Matrix.h> #include <octave/ov.h> #include "pytavedefs.h" #include "arrayobjectdefs.h" #include "exceptions.h" using namespace std; using namespace boost::python; namespace pytave { void pyobj_to_octvalue(octave_value &oct_value, const boost::python::object &py_object); template <class PythonPrimitive, class OctaveBase> static void copy_pyarrobj_to_octarray(OctaveBase &matrix, const PyArrayObject* const pyarr, const int unsigned matindex, const unsigned int matstride, const int dimension, const unsigned int offset) { unsigned char *ptr = (unsigned char*) pyarr->data; if (dimension == pyarr->nd - 1) { // Last dimension, base case for (int i = 0; i < pyarr->dimensions[dimension]; i++) { matrix.elem(matindex + i*matstride) = *(PythonPrimitive*) &ptr[offset + i*pyarr->strides[dimension]]; } } else if (pyarr->nd == 0) { matrix.elem(0) = *(PythonPrimitive*) ptr; } else { for (int i = 0; i < pyarr->dimensions[dimension]; i++) { copy_pyarrobj_to_octarray<PythonPrimitive, OctaveBase>( matrix, pyarr, matindex + i*matstride, matstride * pyarr->dimensions[dimension], dimension + 1, offset + i*pyarr->strides[dimension]); } } } template <> void copy_pyarrobj_to_octarray<PyObject *, Cell>(Cell &matrix, const PyArrayObject* const pyarr, const int unsigned matindex, const unsigned int matstride, const int dimension, const unsigned int offset) { unsigned char *ptr = (unsigned char*) pyarr->data; if (dimension == pyarr->nd - 1) { // Last dimension, base case for (int i = 0; i < pyarr->dimensions[dimension]; i++) { PyObject *pobj = *(PyObject **) &ptr[offset + i*pyarr->strides[dimension]]; pyobj_to_octvalue (matrix.elem(matindex + i*matstride), object(handle<PyObject> (borrowed (pobj)))); } } else if (pyarr->nd == 0) { PyObject *pobj = *(PyObject **) ptr; pyobj_to_octvalue (matrix.elem(0), object(handle<PyObject> (borrowed (pobj)))); } else { for (int i = 0; i < pyarr->dimensions[dimension]; i++) { copy_pyarrobj_to_octarray<PyObject *, Cell>( matrix, pyarr, matindex + i*matstride, matstride * pyarr->dimensions[dimension], dimension + 1, offset + i*pyarr->strides[dimension]); } } } template <class PythonPrimitive, class OctaveBase> static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix, const PyArrayObject* const pyarr, const boost::true_type&) { copy_pyarrobj_to_octarray<PythonPrimitive, OctaveBase> (matrix, pyarr, 0, 1, 0, 0); } template <class PythonPrimitive, class OctaveBase> static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix, const PyArrayObject* const pyarr, const boost::false_type&) { assert(0); } template <class X, class Y> class matching_type : public boost::false_type { }; template <class X> class matching_type<X, X> : public boost::true_type { }; template <class X> class matching_type<X, octave_int<X> > : public boost::true_type { }; template <> class matching_type<float, double> : public boost::true_type { }; template <> class matching_type<FloatComplex, Complex> : public boost::true_type { }; template <> class matching_type<PyObject *, octave_value> : public boost::true_type { }; template <class PythonPrimitive, class OctaveBase> static void copy_pyarrobj_to_octarray_dispatch(OctaveBase &matrix, const PyArrayObject* const pyarr) { matching_type<PythonPrimitive, typename OctaveBase::element_type> inst; copy_pyarrobj_to_octarray_dispatch<PythonPrimitive, OctaveBase> (matrix, pyarr, inst); } template <class OctaveBase> static void copy_pyarrobj_to_octarray_boot(OctaveBase &matrix, const PyArrayObject* const pyarr) { #define ARRAYCASE(AC_pyarrtype, AC_primitive) case AC_pyarrtype: \ copy_pyarrobj_to_octarray_dispatch<AC_primitive, OctaveBase>\ (matrix, pyarr); \ break; \ // Prefer int to other types of the same size. // E.g. on 32-bit x86 architectures: sizeof(long) == sizeof(int). int type_num = pyarr->descr->type_num; switch (type_num) { case PyArray_LONG: if (sizeof(long) == sizeof(int)) { type_num = PyArray_INT; } break; case PyArray_SHORT: if (sizeof(short) == sizeof(int)) { type_num = PyArray_INT; } break; case PyArray_USHORT: if (sizeof(unsigned short) == sizeof(unsigned int)) { type_num = PyArray_UINT; } break; } switch (type_num) { ARRAYCASE(PyArray_CHAR, char) ARRAYCASE(PyArray_UBYTE, unsigned char) ARRAYCASE(PyArray_SBYTE, signed char) ARRAYCASE(PyArray_SHORT, signed short) ARRAYCASE(PyArray_USHORT, unsigned short) ARRAYCASE(PyArray_INT, signed int) ARRAYCASE(PyArray_UINT, unsigned int) ARRAYCASE(PyArray_LONG, signed long) /* Commonly Numeric.array(..., Numeric.Float32) */ ARRAYCASE(PyArray_FLOAT, float) /* Commonly Numeric.array(..., Numeric.Float) */ ARRAYCASE(PyArray_DOUBLE, double) /* Commonly Numeric.array(..., Numeric.Complex32) */ ARRAYCASE(PyArray_CFLOAT, FloatComplex) /* Commonly Numeric.array(..., Numeric.Complex) */ ARRAYCASE(PyArray_CDOUBLE, Complex) #ifdef HAVE_NUMPY ARRAYCASE(PyArray_BOOL, bool) #endif ARRAYCASE(PyArray_OBJECT, PyObject *) default: throw object_convert_exception( PyEval_GetFuncName((PyObject*)pyarr) + (PyEval_GetFuncDesc((PyObject*)pyarr) + string(": Unsupported Python array type"))); } } template <class OctaveBase> static void pyarrobj_to_octvalueNd(octave_value &octvalue, const PyArrayObject* const pyarr, dim_vector dims) { OctaveBase array(dims); copy_pyarrobj_to_octarray_boot<OctaveBase>(array, pyarr); octvalue = array; } static void pyarr_to_octvalue(octave_value &octvalue, const PyArrayObject *pyarr) { dim_vector dims; switch (pyarr->nd) { case 0: dims = dim_vector (1, 1); break; case 1: // Always make PyArray vectors row vectors. dims = dim_vector(1, pyarr->dimensions[0]); break; default: dims.resize(pyarr->nd); for (int d = 0; d < pyarr->nd; d++) { dims(d) = pyarr->dimensions[d]; } break; } switch (pyarr->descr->type_num) { case PyArray_UBYTE: case PyArray_USHORT: case PyArray_UINT: switch (pyarr->descr->elsize) { case 1: pyarrobj_to_octvalueNd<uint8NDArray>(octvalue, pyarr, dims); break; case 2: pyarrobj_to_octvalueNd<uint16NDArray>(octvalue, pyarr, dims); break; case 4: pyarrobj_to_octvalueNd<uint32NDArray>(octvalue, pyarr, dims); break; default: throw object_convert_exception("Unknown unsigned integer."); } case PyArray_SBYTE: case PyArray_SHORT: case PyArray_INT: case PyArray_LONG: switch (pyarr->descr->elsize) { case 1: pyarrobj_to_octvalueNd<int8NDArray>(octvalue, pyarr, dims); break; case 2: pyarrobj_to_octvalueNd<int16NDArray>(octvalue, pyarr, dims); break; case 4: pyarrobj_to_octvalueNd<int32NDArray>(octvalue, pyarr, dims); break; case 8: pyarrobj_to_octvalueNd<int64NDArray>(octvalue, pyarr, dims); break; default: throw object_convert_exception("Unknown integer."); } break; case PyArray_FLOAT: pyarrobj_to_octvalueNd<FloatNDArray>(octvalue, pyarr, dims); break; case PyArray_DOUBLE: pyarrobj_to_octvalueNd<NDArray>(octvalue, pyarr, dims); break; case PyArray_CFLOAT: pyarrobj_to_octvalueNd<FloatComplexNDArray>(octvalue, pyarr, dims); break; case PyArray_CDOUBLE: pyarrobj_to_octvalueNd<ComplexNDArray>(octvalue, pyarr, dims); break; case PyArray_CHAR: pyarrobj_to_octvalueNd<charNDArray>(octvalue, pyarr, dims); // FIXME: is the following needed? octvalue = octvalue.convert_to_str(true, true, '"'); break; #ifdef HAVE_NUMPY case PyArray_BOOL: pyarrobj_to_octvalueNd<boolNDArray>(octvalue, pyarr, dims); break; #endif case PyArray_OBJECT: pyarrobj_to_octvalueNd<Cell>(octvalue, pyarr, dims); break; default: throw object_convert_exception( PyEval_GetFuncDesc((PyObject*)(pyarr)) + string(" ") + PyEval_GetFuncName((PyObject*)(pyarr)) + ": Encountered unsupported Python array"); break; } } static void pylist_to_cellarray(octave_value &oct_value, const boost::python::list &list) { octave_idx_type length = boost::python::extract<octave_idx_type>( list.attr("__len__")()); octave_value_list values; for(octave_idx_type i = 0; i < length; i++) { octave_value val; pyobj_to_octvalue(val, list[i]); values.append(val); } oct_value = Cell(values); } static void pydict_to_octmap(octave_value &oct_value, const boost::python::dict &dict) { boost::python::list list = dict.items(); octave_idx_type length = boost::python::extract<octave_idx_type>( list.attr("__len__")()); dim_vector dims = dim_vector(1, 1); Array<octave_value> vals (length); Array<std::string> keys (length); // Extract all keys and convert values. Remember whether dimensions // match. for(octave_idx_type i = 0; i < length; i++) { std::string& key = keys(i); boost::python::tuple tuple = boost::python::extract<boost::python::tuple>(list[i])(); boost::python::extract<std::string> str(tuple[0]); if(!str.check()) { throw object_convert_exception( string("Can not convert key of type ") + PyEval_GetFuncName(boost::python::object(tuple[0]).ptr()) + PyEval_GetFuncDesc(boost::python::object(tuple[0]).ptr()) + " to a structure field name. Field names must be strings."); } key = str(); if (!valid_identifier(key)) { throw object_convert_exception( string("Can not convert key `") + key + "' to a structure " "field name. Field names must be valid Octave identifiers."); } octave_value& val = vals(i); pyobj_to_octvalue(val, tuple[1]); if(val.is_cell()) { if(i == 0) { dims = val.dims(); } else if (val.numel() != 1 && val.dims() != dims){ throw object_convert_exception( "Dimensions of the struct fields do not match"); } } } Octave_map map = Octave_map(dims); for(octave_idx_type i = 0; i < length; i++) { std::string& key = keys(i); octave_value val = vals(i); if(val.is_cell()) { const Cell c = val.cell_value(); if (c.numel () == 1) { map.assign(key, Cell(dims, c(0))); } else { map.assign(key, c); } } else { map.assign(key, Cell(dims, val)); } } oct_value = map; } void pyobj_to_octvalue(octave_value &oct_value, const boost::python::object &py_object) { extract<int> intx(py_object); extract<double> doublex(py_object); extract<Complex> complexx(py_object); extract<string> stringx(py_object); extract<numeric::array> arrayx(py_object); extract<boost::python::list> listx(py_object); extract<boost::python::dict> dictx(py_object); if (intx.check()) { oct_value = intx(); } else if (doublex.check()) { oct_value = doublex(); } else if (complexx.check()) { oct_value = complexx(); } else if (arrayx.check()) { pyarr_to_octvalue(oct_value, (PyArrayObject*)py_object.ptr()); } else if (stringx.check()) { oct_value = stringx(); } else if (listx.check()) { pylist_to_cellarray(oct_value, (boost::python::list&)py_object); } else if (dictx.check()) { pydict_to_octmap(oct_value, (boost::python::dict&)py_object); } else { throw object_convert_exception( PyEval_GetFuncName(py_object.ptr()) + (PyEval_GetFuncDesc(py_object.ptr()) + string(": Unsupported Python object type, " "cannot convert to Octave value"))); } } void pytuple_to_octlist(octave_value_list &octave_list, const boost::python::tuple &python_tuple) { int length = extract<int>(python_tuple.attr("__len__")()); for (int i = 0; i < length; i++) { pyobj_to_octvalue(octave_list(i), python_tuple[i]); } } } /* Emacs * Local Variables: * fill-column:79 * coding:utf-8 * indent-tabs-mode:nil * c-basic-offset:3 * End: * vim: set textwidth=79 expandtab shiftwidth=3 : */