changeset 70:e3de0f6f1552

experimental NumPy support
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 19 Jun 2009 13:51:36 +0200
parents 4954c14457f2
children af96f7e819e2 077a44d23b54
files ChangeLog arrayobjectdefs.h configure.ac octave_to_python.cc package/pytave.py pytave.cc python_to_octave.cc setup.py.in test/test.py
diffstat 9 files changed, 100 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/ChangeLog	Wed Jun 17 11:49:14 2009 +0200
+++ b/ChangeLog	Fri Jun 19 13:51:36 2009 +0200
@@ -1,3 +1,16 @@
+2009-06-19  Jaroslav Hajek  <highegg@gmail.com>
+
+	* configure.ac: Support --enable-numpy
+	* setup.py.in: Dynamically determine NumPy include path.
+	* pytave.cc (get_module_name): New function.
+	* octave_to_python.cc (octvalue_to_pyarrobj): 
+	Support bool arrays with NumPy.
+	* python_to_octave.cc (pyarr_to_octvalue, 
+	copy_pyarrobj_to_octarray_boot): Likewise.
+	* package/pytave.py: Dynamically import Numeric,
+	forward to numpy.oldnumeric if run with NumPy.
+	* test/test.py: Update some tests.
+
 2009-06-17  Jaroslav Hajek  <highegg@gmail.com>
 
 	* package/pytave.py (stripdict): New function.
--- a/arrayobjectdefs.h	Wed Jun 17 11:49:14 2009 +0200
+++ b/arrayobjectdefs.h	Fri Jun 19 13:51:36 2009 +0200
@@ -29,7 +29,16 @@
 #endif
 #define PY_ARRAY_UNIQUE_SYMBOL pytave_array_symbol
 #include <Python.h>
+#ifdef HAVE_NUMPY
+#include <numpy/oldnumeric.h>
+#include <numpy/old_defines.h>
+// Avoid deprecation warnings from NumPy
+#undef PyArray_FromDims 
+#define PyArray_FromDims PyArray_SimpleNew
+#else
 #include <Numeric/arrayobject.h>
+typedef int npy_intp;
+#endif
 
 /* Emacs
  * Local Variables:
--- a/configure.ac	Wed Jun 17 11:49:14 2009 +0200
+++ b/configure.ac	Fri Jun 19 13:51:36 2009 +0200
@@ -17,6 +17,17 @@
 
 AC_PRESERVE_HELP_ORDER
 
+AC_ARG_ENABLE(numpy,
+	      [AS_HELP_STRING([--enable-numpy],
+			      [use NumPy module (experimental)
+			      @<:@default=no@:>@])],
+			      [pytave_enable_numpy="$enableval"],
+			      [pytave_enable_numpy=no]) dnl TODO: Check?
+ 
+if test "$pytave_enable_numpy" == "yes" ; then
+	AC_DEFINE([HAVE_NUMPY], 1, [Define if using NumPy])
+fi
+
 pytave_libs_ok=
 
 AX_OCTAVE([], [], [pytave_libs_ok=no])
@@ -85,6 +96,7 @@
 PYTAVE_OCTAVE_RPATH="$OCTAVE_LIBRARYDIR"
 AC_SUBST(PYTAVE_OCTAVE_RPATH)
 AC_SUBST(PYTAVE_MODULE_INSTALL_PATH)
+AC_SUBST(pytave_enable_numpy) 
 
 # Substitutes for the Jamfile. XXX: Replace lib*.so with OS independent name.
 AC_SUBST(JAM_LIBOCTAVE, $OCTAVE_LIBRARYDIR/liboctave.so)
--- a/octave_to_python.cc	Wed Jun 17 11:49:14 2009 +0200
+++ b/octave_to_python.cc	Fri Jun 19 13:51:36 2009 +0200
@@ -108,7 +108,7 @@
    static PyArrayObject *createPyArr(const dim_vector &dims,
                                      int pyarrtype) {
       int len = dims.length();
-      int dimensions[len];
+      npy_intp dimensions[len];
       for (int i = 0; i < dims.length(); i++) {
          dimensions[i] = dims(i);
       }
@@ -242,9 +242,15 @@
             matrix.int8_array_value());
       }
       if (matrix.is_bool_type()) {
+#ifdef HAVE_NUMPY
+         // NumPY has logical arrays, and even provides an old-style #define.
+         return create_array<bool, boolNDArray>(
+            matrix.bool_array_value(), PyArray_BOOL);
+#else
          // Numeric does not support bools, use uint8.
          return create_uint_array<uint8NDArray, sizeof(uint8_t)>(
             matrix.uint8_array_value());
+#endif
       }
       if (matrix.is_string()) {
          return create_array<char, charNDArray>(
--- a/package/pytave.py	Wed Jun 17 11:49:14 2009 +0200
+++ b/package/pytave.py	Fri Jun 19 13:51:36 2009 +0200
@@ -23,7 +23,7 @@
 import _pytave
 import UserDict
 import sys
-import Numeric
+
 
 arg0 = sys.argv[0]
 interactive = sys.stdin.isatty() and (arg0 == '' or arg0 == '-')
@@ -32,6 +32,18 @@
 (OctaveError, ValueConvertError, ObjectConvertError, ParseError, \
  VarNameError) = _pytave.get_exceptions();
 
+# Dynamic import. *Must* go after _pytave.init() !
+__modname__ = _pytave.get_module_name()
+if __modname__ == 'numpy':
+    from numpy import oldnumeric as Numeric
+elif __modname__ == 'Numeric':
+    import Numeric
+elif __modname__ == 'numarray':
+    # FIXME: Is this OK?
+    import numarray as Numeric
+else:
+    raise ImportError("Failed to import module: %s" % __modname__)
+
 def feval(nargout, funcname, *arguments):
 
 	"""Executes an Octave function called funcname.
--- a/pytave.cc	Wed Jun 17 11:49:14 2009 +0200
+++ b/pytave.cc	Fri Jun 19 13:51:36 2009 +0200
@@ -80,6 +80,13 @@
       // This is actually a macro that becomes a block expression. If an error
       // occurs, e.g. Numeric Array not installed, an exception is set.
       import_array()
+#ifdef HAVE_NUMPY
+      numeric::array::set_module_and_type ("numpy", "ndarray");
+#endif
+   }
+
+   string get_module_name () {
+      return numeric::array::get_module_name ();
    }
 
    boost::python::tuple get_exceptions() {
@@ -301,6 +308,7 @@
    using namespace boost::python;
 
    def("init", pytave::init);
+   def("get_module_name", pytave::get_module_name);
    def("feval", pytave::func_eval);
    def("eval", pytave::str_eval);
    def("getvar", pytave::getvar);
--- a/python_to_octave.cc	Wed Jun 17 11:49:14 2009 +0200
+++ b/python_to_octave.cc	Fri Jun 19 13:51:36 2009 +0200
@@ -179,6 +179,10 @@
          /* Commonly Numeric.array(..., Numeric.Complex) */
          ARRAYCASE(PyArray_CDOUBLE, Complex)
 
+#ifdef HAVE_NUMPY
+         ARRAYCASE(PyArray_BOOL, bool)
+#endif
+
          ARRAYCASE(PyArray_OBJECT, PyObject *)
 
          default:
@@ -271,6 +275,11 @@
             // FIXME: is the following needed?
             octvalue = octvalue.convert_to_str(true, true, '"');
             break;
+#ifdef HAVE_NUMPY
+         case PyArray_BOOL:
+            pyarrobj_to_octvalueNd<boolNDArray>(octvalue, pyarr, dims);
+            break;
+#endif
          case PyArray_OBJECT:
             pyarrobj_to_octvalueNd<Cell>(octvalue, pyarr, dims);
             break;
--- a/setup.py.in	Wed Jun 17 11:49:14 2009 +0200
+++ b/setup.py.in	Fri Jun 19 13:51:36 2009 +0200
@@ -2,7 +2,19 @@
 # -*- coding: utf-8; c-basic-offset: 3; indent-tabs-mode: nil; tab-width: 3; -*-
 # @configure_input@
 
-from distutils.core import setup, Extension
+from distutils.core import setup, Extension, DistutilsModuleError
+
+include_dirs = ['/usr/local/include/octave-3.1.55', 
+	        '/home/hajek/devel/pytave-repo/pytave', '.'] # Python always included.
+
+# check for numpy. If it exists, define the path.
+
+if '@pytave_enable_numpy@' == 'yes':
+    try:
+        from numpy import get_include
+        include_dirs.append(get_include())
+    except ImportError:
+        raise DistutilsModuleError("could not found numpy")
 
 setup(
    name = 'pytave',
@@ -31,8 +43,8 @@
          
          # TODO: Check whether paths work on Windows or not.
          # The file separator might be wrong. (Must be / in setup.cfg)
-         include_dirs = ['@OCTAVE_INCLUDEDIR@', '@abs_builddir@', '@srcdir@'], # Python always included.
-         define_macros = [('HAVE_CONFIG_H', '1')],
+         include_dirs = include_dirs,
+	 define_macros = [('HAVE_CONFIG_H', '1')],
          library_dirs = ['@OCTAVE_LIBRARYDIR@'],
          runtime_library_dirs = ['@PYTAVE_OCTAVE_RPATH@'],
          libraries = ['octinterp', 'octave', 'cruft', '@BOOST_PYTHON_LIB@']
--- a/test/test.py	Wed Jun 17 11:49:14 2009 +0200
+++ b/test/test.py	Fri Jun 19 13:51:36 2009 +0200
@@ -2,7 +2,7 @@
 # -*- coding:utf-8 -*-
 
 import pytave
-import Numeric
+from pytave import Numeric
 import traceback
 
 print "No messages indicates test pass."
@@ -40,6 +40,9 @@
 
 pytave.feval(1, "test_return", 1)
 
+def equals(a,b):
+    return Numeric.alltrue(Numeric.ravel(a == b))
+
 def fail(msg, exc=None):
 	print "FAIL:", msg
 	traceback.print_stack()
@@ -50,7 +53,7 @@
 def testequal(value):
 	try:
 		nvalue, = pytave.feval(1, "test_return", value)
-		if nvalue != value:
+		if not equals(value, nvalue):
 			fail("as %s != %s" % (value, nvalue))
 	except TypeError, e:
 		fail(value, e)
@@ -58,7 +61,7 @@
 def testexpect(value, expected):
 	try:
 		nvalue, = pytave.feval(1, "test_return", value)
-		if nvalue != expected:
+		if not equals(value, nvalue):
 			fail("sent in %s, expecting %s, got %s", (value, expected, nvalue))
 	except TypeError, e:
 		fail(value, e)
@@ -66,9 +69,11 @@
 def testmatrix(value):
 	try:
 		nvalue, = pytave.feval(1, "test_return", value)
-		class1 = pytave.feval(1, "class", value)
-		class2 = pytave.feval(1, "class", nvalue)
-		if nvalue != value:
+		class1, = pytave.feval(1, "class", value)
+		class1 = class1.tostring()
+		class2, = pytave.feval(1, "class", nvalue)
+		class2 = class2.tostring()
+		if not equals(value, nvalue):
 			fail("as %s != %s" % (value, nvalue))
 		if value.shape != nvalue.shape:
 			fail("Size check failed for: %s. Expected shape %s, got %s  with shape %s" \
@@ -115,7 +120,7 @@
 def testevalexpect(numargout, code, expectations):
 	try:
 		results = pytave.eval(numargout, code);
-		if results != expectations:
+		if not equals(results, expectations):
 			fail("eval: %s : because %s != %s" % (code, results, expectations))
 	except Exception, e:
 		fail("eval: %s" % code, e)
@@ -145,12 +150,14 @@
     def sloppy_factorial(x):
 	pytave.locals["x"] = x
 	xm1, = pytave.eval(1,"x-1")
+	xm1 = xm1.toscalar()
 	if xm1 > 0:
 	    fxm1 = sloppy_factorial(xm1)
 	else:
 	    fxm1 = 1
 	pytave.locals["fxm1"] = fxm1
 	fx, = pytave.eval(1,"x * fxm1")
+	fx = fx.toscalar()
 	return fx
 
     try: