Mercurial > pytave
changeset 70:e3de0f6f1552
experimental NumPy support
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Fri, 19 Jun 2009 13:51:36 +0200 |
parents | 4954c14457f2 |
children | af96f7e819e2 077a44d23b54 |
files | ChangeLog arrayobjectdefs.h configure.ac octave_to_python.cc package/pytave.py pytave.cc python_to_octave.cc setup.py.in test/test.py |
diffstat | 9 files changed, 100 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/ChangeLog Wed Jun 17 11:49:14 2009 +0200 +++ b/ChangeLog Fri Jun 19 13:51:36 2009 +0200 @@ -1,3 +1,16 @@ +2009-06-19 Jaroslav Hajek <highegg@gmail.com> + + * configure.ac: Support --enable-numpy + * setup.py.in: Dynamically determine NumPy include path. + * pytave.cc (get_module_name): New function. + * octave_to_python.cc (octvalue_to_pyarrobj): + Support bool arrays with NumPy. + * python_to_octave.cc (pyarr_to_octvalue, + copy_pyarrobj_to_octarray_boot): Likewise. + * package/pytave.py: Dynamically import Numeric, + forward to numpy.oldnumeric if run with NumPy. + * test/test.py: Update some tests. + 2009-06-17 Jaroslav Hajek <highegg@gmail.com> * package/pytave.py (stripdict): New function.
--- a/arrayobjectdefs.h Wed Jun 17 11:49:14 2009 +0200 +++ b/arrayobjectdefs.h Fri Jun 19 13:51:36 2009 +0200 @@ -29,7 +29,16 @@ #endif #define PY_ARRAY_UNIQUE_SYMBOL pytave_array_symbol #include <Python.h> +#ifdef HAVE_NUMPY +#include <numpy/oldnumeric.h> +#include <numpy/old_defines.h> +// Avoid deprecation warnings from NumPy +#undef PyArray_FromDims +#define PyArray_FromDims PyArray_SimpleNew +#else #include <Numeric/arrayobject.h> +typedef int npy_intp; +#endif /* Emacs * Local Variables:
--- a/configure.ac Wed Jun 17 11:49:14 2009 +0200 +++ b/configure.ac Fri Jun 19 13:51:36 2009 +0200 @@ -17,6 +17,17 @@ AC_PRESERVE_HELP_ORDER +AC_ARG_ENABLE(numpy, + [AS_HELP_STRING([--enable-numpy], + [use NumPy module (experimental) + @<:@default=no@:>@])], + [pytave_enable_numpy="$enableval"], + [pytave_enable_numpy=no]) dnl TODO: Check? + +if test "$pytave_enable_numpy" == "yes" ; then + AC_DEFINE([HAVE_NUMPY], 1, [Define if using NumPy]) +fi + pytave_libs_ok= AX_OCTAVE([], [], [pytave_libs_ok=no]) @@ -85,6 +96,7 @@ PYTAVE_OCTAVE_RPATH="$OCTAVE_LIBRARYDIR" AC_SUBST(PYTAVE_OCTAVE_RPATH) AC_SUBST(PYTAVE_MODULE_INSTALL_PATH) +AC_SUBST(pytave_enable_numpy) # Substitutes for the Jamfile. XXX: Replace lib*.so with OS independent name. AC_SUBST(JAM_LIBOCTAVE, $OCTAVE_LIBRARYDIR/liboctave.so)
--- a/octave_to_python.cc Wed Jun 17 11:49:14 2009 +0200 +++ b/octave_to_python.cc Fri Jun 19 13:51:36 2009 +0200 @@ -108,7 +108,7 @@ static PyArrayObject *createPyArr(const dim_vector &dims, int pyarrtype) { int len = dims.length(); - int dimensions[len]; + npy_intp dimensions[len]; for (int i = 0; i < dims.length(); i++) { dimensions[i] = dims(i); } @@ -242,9 +242,15 @@ matrix.int8_array_value()); } if (matrix.is_bool_type()) { +#ifdef HAVE_NUMPY + // NumPY has logical arrays, and even provides an old-style #define. + return create_array<bool, boolNDArray>( + matrix.bool_array_value(), PyArray_BOOL); +#else // Numeric does not support bools, use uint8. return create_uint_array<uint8NDArray, sizeof(uint8_t)>( matrix.uint8_array_value()); +#endif } if (matrix.is_string()) { return create_array<char, charNDArray>(
--- a/package/pytave.py Wed Jun 17 11:49:14 2009 +0200 +++ b/package/pytave.py Fri Jun 19 13:51:36 2009 +0200 @@ -23,7 +23,7 @@ import _pytave import UserDict import sys -import Numeric + arg0 = sys.argv[0] interactive = sys.stdin.isatty() and (arg0 == '' or arg0 == '-') @@ -32,6 +32,18 @@ (OctaveError, ValueConvertError, ObjectConvertError, ParseError, \ VarNameError) = _pytave.get_exceptions(); +# Dynamic import. *Must* go after _pytave.init() ! +__modname__ = _pytave.get_module_name() +if __modname__ == 'numpy': + from numpy import oldnumeric as Numeric +elif __modname__ == 'Numeric': + import Numeric +elif __modname__ == 'numarray': + # FIXME: Is this OK? + import numarray as Numeric +else: + raise ImportError("Failed to import module: %s" % __modname__) + def feval(nargout, funcname, *arguments): """Executes an Octave function called funcname.
--- a/pytave.cc Wed Jun 17 11:49:14 2009 +0200 +++ b/pytave.cc Fri Jun 19 13:51:36 2009 +0200 @@ -80,6 +80,13 @@ // This is actually a macro that becomes a block expression. If an error // occurs, e.g. Numeric Array not installed, an exception is set. import_array() +#ifdef HAVE_NUMPY + numeric::array::set_module_and_type ("numpy", "ndarray"); +#endif + } + + string get_module_name () { + return numeric::array::get_module_name (); } boost::python::tuple get_exceptions() { @@ -301,6 +308,7 @@ using namespace boost::python; def("init", pytave::init); + def("get_module_name", pytave::get_module_name); def("feval", pytave::func_eval); def("eval", pytave::str_eval); def("getvar", pytave::getvar);
--- a/python_to_octave.cc Wed Jun 17 11:49:14 2009 +0200 +++ b/python_to_octave.cc Fri Jun 19 13:51:36 2009 +0200 @@ -179,6 +179,10 @@ /* Commonly Numeric.array(..., Numeric.Complex) */ ARRAYCASE(PyArray_CDOUBLE, Complex) +#ifdef HAVE_NUMPY + ARRAYCASE(PyArray_BOOL, bool) +#endif + ARRAYCASE(PyArray_OBJECT, PyObject *) default: @@ -271,6 +275,11 @@ // FIXME: is the following needed? octvalue = octvalue.convert_to_str(true, true, '"'); break; +#ifdef HAVE_NUMPY + case PyArray_BOOL: + pyarrobj_to_octvalueNd<boolNDArray>(octvalue, pyarr, dims); + break; +#endif case PyArray_OBJECT: pyarrobj_to_octvalueNd<Cell>(octvalue, pyarr, dims); break;
--- a/setup.py.in Wed Jun 17 11:49:14 2009 +0200 +++ b/setup.py.in Fri Jun 19 13:51:36 2009 +0200 @@ -2,7 +2,19 @@ # -*- coding: utf-8; c-basic-offset: 3; indent-tabs-mode: nil; tab-width: 3; -*- # @configure_input@ -from distutils.core import setup, Extension +from distutils.core import setup, Extension, DistutilsModuleError + +include_dirs = ['/usr/local/include/octave-3.1.55', + '/home/hajek/devel/pytave-repo/pytave', '.'] # Python always included. + +# check for numpy. If it exists, define the path. + +if '@pytave_enable_numpy@' == 'yes': + try: + from numpy import get_include + include_dirs.append(get_include()) + except ImportError: + raise DistutilsModuleError("could not found numpy") setup( name = 'pytave', @@ -31,8 +43,8 @@ # TODO: Check whether paths work on Windows or not. # The file separator might be wrong. (Must be / in setup.cfg) - include_dirs = ['@OCTAVE_INCLUDEDIR@', '@abs_builddir@', '@srcdir@'], # Python always included. - define_macros = [('HAVE_CONFIG_H', '1')], + include_dirs = include_dirs, + define_macros = [('HAVE_CONFIG_H', '1')], library_dirs = ['@OCTAVE_LIBRARYDIR@'], runtime_library_dirs = ['@PYTAVE_OCTAVE_RPATH@'], libraries = ['octinterp', 'octave', 'cruft', '@BOOST_PYTHON_LIB@']
--- a/test/test.py Wed Jun 17 11:49:14 2009 +0200 +++ b/test/test.py Fri Jun 19 13:51:36 2009 +0200 @@ -2,7 +2,7 @@ # -*- coding:utf-8 -*- import pytave -import Numeric +from pytave import Numeric import traceback print "No messages indicates test pass." @@ -40,6 +40,9 @@ pytave.feval(1, "test_return", 1) +def equals(a,b): + return Numeric.alltrue(Numeric.ravel(a == b)) + def fail(msg, exc=None): print "FAIL:", msg traceback.print_stack() @@ -50,7 +53,7 @@ def testequal(value): try: nvalue, = pytave.feval(1, "test_return", value) - if nvalue != value: + if not equals(value, nvalue): fail("as %s != %s" % (value, nvalue)) except TypeError, e: fail(value, e) @@ -58,7 +61,7 @@ def testexpect(value, expected): try: nvalue, = pytave.feval(1, "test_return", value) - if nvalue != expected: + if not equals(value, nvalue): fail("sent in %s, expecting %s, got %s", (value, expected, nvalue)) except TypeError, e: fail(value, e) @@ -66,9 +69,11 @@ def testmatrix(value): try: nvalue, = pytave.feval(1, "test_return", value) - class1 = pytave.feval(1, "class", value) - class2 = pytave.feval(1, "class", nvalue) - if nvalue != value: + class1, = pytave.feval(1, "class", value) + class1 = class1.tostring() + class2, = pytave.feval(1, "class", nvalue) + class2 = class2.tostring() + if not equals(value, nvalue): fail("as %s != %s" % (value, nvalue)) if value.shape != nvalue.shape: fail("Size check failed for: %s. Expected shape %s, got %s with shape %s" \ @@ -115,7 +120,7 @@ def testevalexpect(numargout, code, expectations): try: results = pytave.eval(numargout, code); - if results != expectations: + if not equals(results, expectations): fail("eval: %s : because %s != %s" % (code, results, expectations)) except Exception, e: fail("eval: %s" % code, e) @@ -145,12 +150,14 @@ def sloppy_factorial(x): pytave.locals["x"] = x xm1, = pytave.eval(1,"x-1") + xm1 = xm1.toscalar() if xm1 > 0: fxm1 = sloppy_factorial(xm1) else: fxm1 = 1 pytave.locals["fxm1"] = fxm1 fx, = pytave.eval(1,"x * fxm1") + fx = fx.toscalar() return fx try: