view octave_to_python.cc @ 75:b0991511a16d

a few fixes to numpy support
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 15 Sep 2009 11:23:36 +0200
parents e3de0f6f1552
children ba609c4d77db
line wrap: on
line source

/*
 *  Copyright 2008 David Grundberg, HÃ¥kan Fors Nilsson
 *  Copyright 2009 VZLU Prague
 *
 *  This file is part of Pytave.
 *
 *  Pytave is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  Pytave is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with Pytave.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <boost/python.hpp>
#include <boost/python/numeric.hpp>
#include <boost/type_traits/integral_constant.hpp>
#undef HAVE_STAT /* both boost::python and octave define HAVE_STAT... */
#include <octave/oct.h>
#include <octave/Matrix.h>
#include <octave/ov.h>
#include <octave/oct-map.h>

#include <iostream>
#include "pytavedefs.h"
#include "arrayobjectdefs.h"
#include "exceptions.h"
#include "octave_to_python.h"

/* From docs:
 *  Note that the names of the element type constants refer to the C data
 *  types, not the Python data types. A Python int is equivalent to a C long,
 *  and a Python float corresponds to a C double . Many of the element types
 *  listed above do not have corresponding Python scalar types
 *  (e.g. PyArray_INT ).
 */

using namespace std;
using namespace boost::python;

namespace pytave {

   template <class PythonPrimitive, class OctaveBase>
   static void copy_octarray_to_pyarrobj(
                                  PyArrayObject *pyarr,
                                  const OctaveBase &matrix,
                                  const unsigned int matindex,
                                  const unsigned int matstride,
                                  const int dimension,
                                  const unsigned int offset) {
      unsigned char *ptr = (unsigned char*) pyarr->data;
      if (dimension == pyarr->nd - 1) {
         // Last dimension, base case
         for (int i = 0; i < pyarr->dimensions[dimension]; i++) {
            *(PythonPrimitive *)&ptr[offset + i*pyarr->strides[dimension]]
               = matrix.elem(matindex + i*matstride);
         }
      } else {
         for (int i = 0; i < pyarr->dimensions[dimension]; i++) {
            copy_octarray_to_pyarrobj<PythonPrimitive, OctaveBase>(
               pyarr,
               matrix,
               matindex + i*matstride,
               matstride * pyarr->dimensions[dimension],
               dimension + 1,
               offset + i*pyarr->strides[dimension]);
         }
      }
   }

   template <>
   void copy_octarray_to_pyarrobj<PyObject *, Cell>(
                                  PyArrayObject *pyarr,
                                  const Cell &matrix,
                                  const unsigned int matindex,
                                  const unsigned int matstride,
                                  const int dimension,
                                  const unsigned int offset) {
      unsigned char *ptr = (unsigned char*) pyarr->data;
      if (dimension == pyarr->nd - 1) {
         // Last dimension, base case
         for (int i = 0; i < pyarr->dimensions[dimension]; i++) {
            object pyobj;
            octvalue_to_pyobj (pyobj, matrix.elem(matindex + i*matstride));
            Py_INCREF (pyobj.ptr());
            *(PyObject **)&ptr[offset + i*pyarr->strides[dimension]]
               = pyobj.ptr();
         }
      } else {
         for (int i = 0; i < pyarr->dimensions[dimension]; i++) {
            copy_octarray_to_pyarrobj<PyObject *, Cell>(
               pyarr,
               matrix,
               matindex + i*matstride,
               matstride * pyarr->dimensions[dimension],
               dimension + 1,
               offset + i*pyarr->strides[dimension]);
         }
      }
   }

   static PyArrayObject *createPyArr(const dim_vector &dims,
                                     int pyarrtype) {
      int len = dims.length();
      npy_intp dimensions[len];
      for (int i = 0; i < dims.length(); i++) {
         dimensions[i] = dims(i);
      }

      return (PyArrayObject *)PyArray_FromDims(
         len, dimensions, pyarrtype);
   }

   template <class PythonPrimitive, class OctaveBase>
   static PyArrayObject *create_array(const OctaveBase &octarr,
                                      int pyarraytype) {
      PyArrayObject *pyarr = createPyArr(octarr.dims(), pyarraytype);
      try {
         copy_octarray_to_pyarrobj
            <PythonPrimitive, OctaveBase>(pyarr, octarr, 0, 1, 0, 0);
      } catch (value_convert_exception &pe) {
         Py_DECREF(pyarr);
         throw;
      }
      return pyarr;
   }

   template <class PythonPrimitive, class OctaveBase>
   static PyArrayObject *create_array(const OctaveBase &octarr,
                                      int pyarraytype, boost::true_type) {
      return create_array<PythonPrimitive, OctaveBase> (octarr, pyarraytype);
   }

   template <class PythonPrimitive, class OctaveBase>
   static PyArrayObject *create_array(const OctaveBase &octarr,
                                      int pyarraytype, boost::false_type) {
      assert(0);
      return 0;
   }

   template <class CLASS, size_t bytes>
   inline static PyArrayObject *create_uint_array(CLASS value) {
      if (bytes == sizeof(int)) {
         boost::integral_constant<bool, bytes == sizeof(int)> inst;
         return create_array<unsigned int, CLASS>(value, PyArray_UINT, inst);
      } else if (bytes == sizeof(char)) {
         boost::integral_constant<bool, bytes == sizeof(char)> inst;
         return create_array<unsigned char, CLASS>(value, PyArray_UBYTE, inst);
      } else if (bytes == sizeof(short)) {
         boost::integral_constant<bool,
            bytes == sizeof(short) && bytes != sizeof(int)> inst;
         return create_array<unsigned short, CLASS>(value, PyArray_USHORT, inst);
      } else {
         ostringstream os;
         os << "Numeric arrays doesn't support unsigned " << (bytes*8)
            << "-bit values on this architecture.";
         throw value_convert_exception(os.str());
      }
   }

   template <class CLASS, size_t bytes>
   inline static PyArrayObject *create_sint_array(CLASS value) {
      if (bytes == sizeof(int)) {
         // We test int first since we prefer int to other datatypes of the
         // same size.
         boost::integral_constant<bool, bytes == sizeof(int)> inst;
         return create_array<signed int, CLASS>(value, PyArray_INT, inst);
      } else if (bytes == sizeof(long)) {
         boost::integral_constant<bool,
            bytes==sizeof(long) && bytes != sizeof(int)> inst;
         return create_array<long, CLASS>(value, PyArray_LONG, inst);
      } else if (bytes == sizeof(char)) {
         boost::integral_constant<bool, bytes == sizeof(char)> inst;
         return create_array<signed char, CLASS>(value, PyArray_SBYTE, inst);
      } else if (bytes == sizeof(short)) {
         boost::integral_constant<bool,
            bytes==sizeof(short) && bytes != sizeof(int)> inst;
         return create_array<signed short, CLASS>(value, PyArray_SHORT, inst);
      } else {
         ostringstream os;
         os << "Numeric arrays doesn't support signed " << (bytes*8)
            << "-bit values on this architecture.";
         throw value_convert_exception(os.str());
      }
   }

   static PyArrayObject *octvalue_to_pyarrobj(const octave_value &matrix) {
      if (matrix.is_double_type ()) {
         if (matrix.is_complex_type ()) {
            return create_array<Complex, ComplexNDArray>
               (matrix.complex_array_value(), PyArray_CDOUBLE);
         } else if (matrix.is_real_type()) {
            return create_array<double, NDArray>(matrix.array_value(),
                                                 PyArray_DOUBLE);
         } else
            throw value_convert_exception("Unknown double matrix type");
      }

      if (matrix.is_single_type ()) {
         if (matrix.is_complex_type ()) {
            return create_array<FloatComplex, FloatComplexNDArray>
               (matrix.float_complex_array_value(), PyArray_CFLOAT);
         } else if (matrix.is_real_type()) {
            return create_array<float, FloatNDArray>(
               matrix.float_array_value(), PyArray_FLOAT);
         } else
            throw value_convert_exception("Unknown float matrix type");
      }

      if (matrix.is_int64_type()) {
         return create_sint_array<int64NDArray, sizeof(int64_t)>(
            matrix.int64_array_value());
      }
      if (matrix.is_uint32_type()) {
         return create_uint_array<uint32NDArray, sizeof(uint32_t)>(
            matrix.uint32_array_value());
      }
      if (matrix.is_int32_type()) {
         return create_sint_array<int32NDArray, sizeof(int32_t)>(
            matrix.int32_array_value());
      }
      if (matrix.is_uint16_type()) {
         return create_uint_array<uint16NDArray, sizeof(uint16_t)>(
            matrix.uint16_array_value());
      }
      if (matrix.is_int16_type()) {
         return create_sint_array<int16NDArray, sizeof(int16_t)>(
            matrix.int16_array_value());
      }
      if (matrix.is_uint8_type()) {
         return create_uint_array<uint8NDArray, sizeof(uint8_t)>(
            matrix.uint8_array_value());
      }
      if (matrix.is_int8_type()) {
         return create_sint_array<int8NDArray, sizeof(int8_t)>(
            matrix.int8_array_value());
      }
      if (matrix.is_bool_type()) {
#ifdef HAVE_NUMPY
         // NumPY has logical arrays, and even provides an old-style #define.
         return create_array<bool, boolNDArray>(
            matrix.bool_array_value(), PyArray_BOOL);
#else
         // Numeric does not support bools, use uint8.
         return create_uint_array<uint8NDArray, sizeof(uint8_t)>(
            matrix.uint8_array_value());
#endif
      }
      if (matrix.is_string()) {
         return create_array<char, charNDArray>(
            matrix.char_array_value(), PyArray_CHAR);
      }
      if (matrix.is_cell()) {
         return create_array<PyObject *, Cell>(
            matrix.cell_value(), PyArray_OBJECT);
      }

      throw value_convert_exception("Octave matrix type not known, "
                                    "conversion not implemented");
   }

   static void octvalue_to_pyarr(boost::python::object &py_object,
                          const octave_value& octvalue) {
      PyArrayObject *pyarr = octvalue_to_pyarrobj(octvalue);
      py_object = object(handle<PyObject>((PyObject *)pyarr));
   }

   static void octmap_to_pyobject(boost::python::object &py_object,
                                  const Octave_map& map) {
      py_object = boost::python::dict();
      string_vector keys = map.keys();

      for(octave_idx_type i = 0 ; i < keys.length(); i++) {
         boost::python::object py_val;

         const Cell c = map.contents(keys[i]);

         octvalue_to_pyarr(py_val, c);

         py_object[keys[i]] = py_val;
      }
   }

   void octvalue_to_pyobj(boost::python::object &py_object,
                          const octave_value& octvalue) {
      if (octvalue.is_undefined()) {
         throw value_convert_exception(
            "Octave value `undefined'. Can not convert to a Python object");
      } else if (octvalue.is_numeric_type() || octvalue.is_string()
                 || octvalue.is_cell()) {
         octvalue_to_pyarr(py_object, octvalue);
      } else if (octvalue.is_map()) {
         octmap_to_pyobject(py_object, octvalue.map_value());
      } else
         throw value_convert_exception(
            "Conversion from Octave value not implemented");
   }

   void octlist_to_pytuple(boost::python::tuple &python_tuple,
                           const octave_value_list &octave_list) {
      boost::python::list seq;
      int length = octave_list.length();
      for (int i = 0; i < length; i++) {
         boost::python::object py_object;
         octvalue_to_pyobj(py_object, octave_list(i));
         seq.append(py_object);
      }
      python_tuple = tuple(seq);
   }
}

/* Emacs
 * Local Variables:
 * fill-column:79
 * coding:utf-8
 * indent-tabs-mode:nil
 * c-basic-offset:3
 * End:
 * vim: set textwidth=79 expandtab shiftwidth=3 :
 */