changeset 14951:4c9fd3e31436

Start of jit support for double matrices
author Max Brister <max@2bass.com>
date Thu, 14 Jun 2012 16:38:06 -0500
parents 7ab3ac5c676c
children e3696f2c6da6
files liboctave/Array.h liboctave/MArray.h liboctave/dNDArray.h liboctave/dim-vector.h liboctave/oct-refcount.h src/pt-jit.cc src/pt-jit.h
diffstat 7 files changed, 543 insertions(+), 98 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array.h	Mon Jun 11 20:11:20 2012 -0500
+++ b/liboctave/Array.h	Thu Jun 14 16:38:06 2012 -0500
@@ -164,6 +164,14 @@
       return &nr;
     }
 
+protected:
+
+  // For jit support
+  Array (T *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep)
+    : dimensions (adims),
+      rep (reinterpret_cast<typename Array<T>::ArrayRep *> (arep)),
+      slice_data (sdata), slice_len (slen) {}
+
 public:
 
   // Empty ctor (0x0).
@@ -693,6 +701,16 @@
   // supposedly equal dimensions (e.g. structs in the interpreter).
   bool optimize_dimensions (const dim_vector& dv);
 
+  // WARNING: Only call these functions from jit
+
+  int *jit_ref_count (void) { return rep->count.get (); }
+
+  T *jit_slice_data (void) const { return slice_data; }
+
+  octave_idx_type *jit_dimensions (void) const { return dimensions.to_jit (); }
+
+  void *jit_array_rep (void) const { return rep; }
+
 private:
 
   void resize2 (octave_idx_type nr, octave_idx_type nc, const T& rfv);
--- a/liboctave/MArray.h	Mon Jun 11 20:11:20 2012 -0500
+++ b/liboctave/MArray.h	Thu Jun 14 16:38:06 2012 -0500
@@ -39,6 +39,12 @@
 class
 MArray : public Array<T>
 {
+protected:
+
+  // For jit support
+  MArray (T *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep)
+    : Array<T> (sdata, slen, adims, arep) { }
+
 public:
 
   MArray (void) : Array<T> () {}
--- a/liboctave/dNDArray.h	Mon Jun 11 20:11:20 2012 -0500
+++ b/liboctave/dNDArray.h	Thu Jun 14 16:38:06 2012 -0500
@@ -64,6 +64,10 @@
 
   NDArray (const charNDArray&);
 
+  // For jit support only
+  NDArray (double *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep)
+    : MArray<double> (sdata, slen, adims, arep) { }
+
   NDArray& operator = (const NDArray& a)
     {
       MArray<double>::operator = (a);
--- a/liboctave/dim-vector.h	Mon Jun 11 20:11:20 2012 -0500
+++ b/liboctave/dim-vector.h	Thu Jun 14 16:38:06 2012 -0500
@@ -212,6 +212,12 @@
 
   void chop_all_singletons (void);
 
+  // WARNING: Only call this function from jit
+  octave_idx_type *to_jit (void) const
+  {
+    return rep;
+  }
+
 private:
 
   static octave_idx_type *nil_rep (void)
@@ -220,9 +226,6 @@
       return zv.rep;
     }
 
-  explicit dim_vector (octave_idx_type *r)
-    : rep (r) { }
-
 public:
 
   static octave_idx_type dim_max (void);
@@ -233,6 +236,10 @@
   dim_vector (const dim_vector& dv) : rep (dv.rep)
   { OCTREFCOUNT_ATOMIC_INCREMENT (&(count())); }
 
+  // FIXME: Should be private, but required by array constructor for jit
+  explicit dim_vector (octave_idx_type *r)
+    : rep (r) { }
+
   static dim_vector alloc (int n)
   {
     return dim_vector (newrep (n < 2 ? 2 : n));
--- a/liboctave/oct-refcount.h	Mon Jun 11 20:11:20 2012 -0500
+++ b/liboctave/oct-refcount.h	Thu Jun 14 16:38:06 2012 -0500
@@ -82,6 +82,11 @@
       return static_cast<count_type const volatile&> (count);
     }
 
+  count_type *get (void)
+    {
+      return &count;
+    }
+
 private:
   count_type count;
 };
--- a/src/pt-jit.cc	Mon Jun 11 20:11:20 2012 -0500
+++ b/src/pt-jit.cc	Thu Jun 14 16:38:06 2012 -0500
@@ -23,6 +23,8 @@
 #define __STDC_LIMIT_MACROS
 #define __STDC_CONSTANT_MACROS
 
+#define OCTAVE_JIT_DEBUG
+
 #ifdef HAVE_CONFIG_H
 #include <config.h>
 #endif
@@ -147,6 +149,12 @@
   obv->release ();
 }
 
+extern "C" void
+octave_jit_delete_matrix (jit_matrix *m)
+{
+  NDArray array (*m);
+}
+
 extern "C" octave_base_value *
 octave_jit_grab_any (octave_base_value *obv)
 {
@@ -154,6 +162,25 @@
   return obv;
 }
 
+extern "C" octave_base_value *
+octave_jit_cast_any_matrix (jit_matrix *jmatrix)
+{
+  ++(*jmatrix->ref_count);
+  NDArray matrix = *jmatrix;
+  octave_value ret (matrix);
+
+  octave_base_value *rep = ret.internal_rep ();
+  rep->grab ();
+  return rep;
+}
+
+extern "C" void
+octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv)
+{
+  NDArray m = obv->array_value ();
+  *ret = m;
+}
+
 extern "C" double
 octave_jit_cast_scalar_any (octave_base_value *obv)
 {
@@ -190,6 +217,40 @@
   return obv;
 }
 
+extern "C" void
+octave_jit_ginvalid_index (void)
+{
+  try
+    {
+      gripe_invalid_index ();      
+    }
+  catch (const octave_execution_exception&)
+    {
+      gripe_library_execution_error ();
+    }
+}
+
+extern "C" void
+octave_jit_gindex_range (int nd, int dim, octave_idx_type iext,
+                         octave_idx_type ext)
+{
+  std::cout << "gindex_range\n";
+  try
+    {
+      gripe_index_out_of_range (nd, dim, iext, ext);
+    }
+  catch (const octave_execution_exception&)
+    {
+      gripe_library_execution_error ();
+    }
+}
+
+extern "C" void
+octave_jit_print_matrix (jit_matrix *m)
+{
+  std::cout << *m << std::endl;
+}
+
 // -------------------- jit_range --------------------
 std::ostream&
 operator<< (std::ostream& os, const jit_range& rng)
@@ -198,6 +259,16 @@
             << ", " << rng.nelem << "]";
 }
 
+// -------------------- jit_matrix --------------------
+
+std::ostream&
+operator<< (std::ostream& os, const jit_matrix& mat)
+{
+  return os << "Matrix[" << mat.ref_count << ", " << mat.slice_data << ", "
+            << mat.slice_len << ", " << mat.dimensions << ", "
+            << mat.array_rep << "]";
+}
+
 // -------------------- jit_type --------------------
 llvm::Type *
 jit_type::to_llvm_arg (void) const
@@ -291,34 +362,36 @@
   : module (m), engine (e), next_id (0)
 {
   // FIXME: We should be registering types like in octave_value_typeinfo
-  ov_t = llvm::StructType::create (context, "octave_base_value");
-  ov_t = ov_t->getPointerTo ();
-
-  llvm::Type *dbl = llvm::Type::getDoubleTy (context);
+  llvm::Type *any_t = llvm::StructType::create (context, "octave_base_value");
+  any_t = any_t->getPointerTo ();
+
+  llvm::Type *scalar_t = llvm::Type::getDoubleTy (context);
   llvm::Type *bool_t = llvm::Type::getInt1Ty (context);
   llvm::Type *string_t = llvm::Type::getInt8Ty (context);
   string_t = string_t->getPointerTo ();
-  llvm::Type *index_t = 0;
-  switch (sizeof(octave_idx_type))
-    {
-    case 4:
-      index_t = llvm::Type::getInt32Ty (context);
-      break;
-    case 8:
-      index_t = llvm::Type::getInt64Ty (context);
-      break;
-    default:
-      assert (false && "Unrecognized index type size");
-    }
+  llvm::Type *index_t = llvm::Type::getIntNTy (context, sizeof(octave_idx_type) * 8);
 
   llvm::StructType *range_t = llvm::StructType::create (context, "range");
-  std::vector<llvm::Type *> range_contents (4, dbl);
+  std::vector<llvm::Type *> range_contents (4, scalar_t);
   range_contents[3] = index_t;
   range_t->setBody (range_contents);
 
+  llvm::Type *refcount_t = llvm::Type::getIntNTy (context, sizeof(int) * 8);
+  llvm::Type *int_t = refcount_t;
+
+  llvm::StructType *matrix_t = llvm::StructType::create (context, "matrix");
+  llvm::Type *matrix_contents[5];
+  matrix_contents[0] = refcount_t->getPointerTo ();
+  matrix_contents[1] = scalar_t->getPointerTo ();
+  matrix_contents[2] = index_t;
+  matrix_contents[3] = index_t->getPointerTo ();
+  matrix_contents[4] = string_t;
+  matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5));
+
   // create types
-  any = new_type ("any", 0, ov_t);
-  scalar = new_type ("scalar", any, dbl);
+  any = new_type ("any", 0, any_t);
+  matrix = new_type ("matrix", any, matrix_t);
+  scalar = new_type ("scalar", any, scalar_t);
   range = new_type ("range", any, range_t);
   string = new_type ("string", any, string_t);
   boolean = new_type ("bool", any, bool_t);
@@ -378,6 +451,27 @@
   grab_fn.add_overload (fn, false, any, any);
   grab_fn.stash_name ("grab");
 
+  // grab matrix
+  llvm::Function *print_matrix = create_function ("octave_jit_print_matrix",
+                                                  void_t,
+                                                  matrix_t->getPointerTo ());
+  engine->addGlobalMapping (print_matrix, reinterpret_cast<void*>(&octave_jit_print_matrix));
+
+  fn = create_function ("octave_jit_grab_matrix", matrix, matrix);
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *one = llvm::ConstantInt::get (refcount_t, 1);
+
+    llvm::Value *mat = fn->arg_begin ();
+    llvm::Value *rcount= builder.CreateExtractValue (mat, 0);
+    llvm::Value *count = builder.CreateLoad (rcount);
+    count = builder.CreateAdd (count, one);
+    builder.CreateStore (count, rcount);
+    builder.CreateRet (mat);
+  }
+  grab_fn.add_overload (fn, false, matrix, matrix);
+
   // grab scalar
   fn = create_identity (scalar);
   grab_fn.add_overload (fn, false, scalar, scalar);
@@ -387,11 +481,45 @@
   grab_fn.add_overload (fn, false, index, index);
 
   // release any
-  fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ());
+  fn = create_function ("octave_jit_release_any", void_t, any_t);
   engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any));
   release_fn.add_overload (fn, false, 0, any);
   release_fn.stash_name ("release");
 
+  // release matrix
+  llvm::Function *delete_mat = create_function ("octave_jit_delete_matrix", void_t,
+                                                matrix_t);
+  engine->addGlobalMapping (delete_mat,
+                            reinterpret_cast<void*> (&octave_jit_delete_matrix));
+
+  fn = create_function ("octave_jit_release_matrix", void_t, matrix_t);
+  llvm::Function *release_mat = fn;
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *one = llvm::ConstantInt::get (refcount_t, 1);
+    llvm::Value *zero = llvm::ConstantInt::get (refcount_t, 0);
+
+    llvm::Value *mat = fn->arg_begin ();
+    llvm::Value *rcount= builder.CreateExtractValue (mat, 0);
+    llvm::Value *count = builder.CreateLoad (rcount);
+    count = builder.CreateSub (count, one);
+
+    llvm::BasicBlock *dead = llvm::BasicBlock::Create (context, "dead", fn);
+    llvm::BasicBlock *live = llvm::BasicBlock::Create (context, "live", fn);
+    llvm::Value *isdead = builder.CreateICmpEQ (count, zero);
+    builder.CreateCondBr (isdead, dead, live);
+
+    builder.SetInsertPoint (dead);
+    builder.CreateCall (delete_mat, mat);
+    builder.CreateRetVoid ();
+
+    builder.SetInsertPoint (live);
+    builder.CreateStore (count, rcount);
+    builder.CreateRetVoid ();
+  }
+  release_fn.add_overload (fn, false, 0, matrix);
+
   // release scalar
   fn = create_identity (scalar);
   release_fn.add_overload (fn, false, 0, scalar);
@@ -429,13 +557,13 @@
 
   // divide is annoying because it might error
   fn = create_function ("octave_jit_div_scalar_scalar", scalar, scalar, scalar);
-  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
+  body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   {
     llvm::BasicBlock *warn_block = llvm::BasicBlock::Create (context, "warn", fn);
     llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, "normal", fn);
 
-    llvm::Value *zero = llvm::ConstantFP::get (dbl, 0);
+    llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0);
     llvm::Value *check = builder.CreateFCmpUEQ (zero, ++fn->arg_begin ());
     builder.CreateCondBr (check, warn_block, normal_block);
 
@@ -514,7 +642,7 @@
   builder.SetInsertPoint (body);
   {
     llvm::Value *idx = ++fn->arg_begin ();
-    llvm::Value *didx = builder.CreateUIToFP (idx, dbl);
+    llvm::Value *didx = builder.CreateSIToFP (idx, scalar_t);
     llvm::Value *rng = fn->arg_begin ();
     llvm::Value *base = builder.CreateExtractValue (rng, 0);
     llvm::Value *inc = builder.CreateExtractValue (rng, 2);
@@ -548,7 +676,7 @@
     builder.CreateBr (normal_block);
     builder.SetInsertPoint (normal_block);
 
-    llvm::Value *zero = llvm::ConstantFP::get (dbl, 0);
+    llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0);
     llvm::Value *ret = builder.CreateFCmpONE (fn->arg_begin (), zero);
     builder.CreateRet (ret);
   }
@@ -580,7 +708,7 @@
     llvm::Value *inc = ++args;
     llvm::Value *nelem = builder.CreateCall3 (compute_nelem, base, limit, inc);
 
-    llvm::Value *dzero = llvm::ConstantFP::get (dbl, 0);
+    llvm::Value *dzero = llvm::ConstantFP::get (scalar_t, 0);
     llvm::Value *izero = llvm::ConstantInt::get (index_t, 0);
     llvm::Value *rng = llvm::ConstantStruct::get (range_t, dzero, dzero, dzero,
                                                   izero, NULL);
@@ -593,9 +721,110 @@
   llvm::verifyFunction (*fn);
   make_range_fn.add_overload (fn, false, range, scalar, scalar, scalar);
 
+  // paren_subsref
+  llvm::Function *ginvalid_index = create_function ("gipe_invalid_index", void_t);
+  engine->addGlobalMapping (ginvalid_index,
+                            reinterpret_cast<void*> (&octave_jit_ginvalid_index));
+
+  llvm::Function *gindex_range = create_function ("gripe_index_out_of_range",
+                                                  void_t, int_t, int_t, index_t,
+                                                  index_t);
+  engine->addGlobalMapping (gindex_range,
+                            reinterpret_cast<void*> (&octave_jit_gindex_range));
+
+  fn = create_function ("()subsref", scalar, matrix, scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *one = llvm::ConstantInt::get (index_t, 1);
+    llvm::Value *ione;
+    if (index_t == int_t)
+      ione = one;
+    else
+      ione = llvm::ConstantInt::get (int_t, 1);
+            
+
+    llvm::Value *szero = llvm::ConstantFP::get (scalar_t, 0);
+
+    llvm::Function::arg_iterator args = fn->arg_begin ();
+    llvm::Value *mat = args++;
+    llvm::Value *idx = args;
+
+    // convert the scalar index to an integer, and check that it is integral and >= 1
+    llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t);
+    llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t);
+    llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx);
+    llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one);
+    llvm::Value *cond = builder.CreateOr (cond0, cond1);
+
+    llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn);
+
+    llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context,
+                                                             "conv_error", fn,
+                                                             done);
+    llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn,
+                                                         done);
+    builder.CreateCondBr (cond, conv_error, normal);
+
+    builder.SetInsertPoint (conv_error);
+    builder.CreateCall (ginvalid_index);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (normal);
+    llvm::Value *len = builder.CreateExtractValue (mat,
+                                                   llvm::ArrayRef<unsigned> (2));
+    cond = builder.CreateICmpSGT (int_idx, len);
+
+
+    llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context,
+                                                               "bounds_error",
+                                                               fn, done);
+
+    llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success",
+                                                          fn, done);
+    builder.CreateCondBr (cond, bounds_error, success);
+
+    builder.SetInsertPoint (bounds_error);
+    builder.CreateCall4 (gindex_range, ione, ione, int_idx, len);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (success);
+    llvm::Value *data = builder.CreateExtractValue (mat,
+                                                    llvm::ArrayRef<unsigned> (1));
+    llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx);
+    llvm::Value *ret = builder.CreateLoad (gep);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (done);
+
+    llvm::PHINode *merge = llvm::PHINode::Create (scalar_t, 3);
+    builder.Insert (merge);
+    merge->addIncoming (szero, conv_error);
+    merge->addIncoming (szero, bounds_error);
+    merge->addIncoming (ret, success);
+    builder.CreateCall (release_mat, mat);
+    builder.CreateRet (merge);
+  }
+  llvm::verifyFunction (*fn);
+  paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar);
+
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
 
+  // cast any <- matrix
+  fn = create_function ("octave_jit_cast_any_matrix", any_t,
+                        matrix_t->getPointerTo ());
+  engine->addGlobalMapping (fn,
+                            reinterpret_cast<void*> (&octave_jit_cast_any_matrix));
+  casts[any->type_id ()].add_overload (fn, false, any, matrix);
+
+  // cast matrix <- any
+  fn = create_function ("octave_jit_cast_matrix_any", void_t,
+                        matrix_t->getPointerTo (), any_t);
+  engine->addGlobalMapping (fn,
+                            reinterpret_cast<void*> (&octave_jit_cast_matrix_any));
+  casts[matrix->type_id ()].add_overload (fn, false, matrix, any);
+
   // cast any <- scalar
   fn = create_function ("octave_jit_cast_any_scalar", any, scalar);
   engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_scalar));
@@ -740,14 +969,20 @@
 jit_typeinfo::do_type_of (const octave_value &ov) const
 {
   if (ov.is_function ())
-    return 0;
-
-  if (ov.is_double_type () && ov.is_real_scalar ())
-    return get_scalar ();
+    return 0; // functions are not supported
 
   if (ov.is_range ())
     return get_range ();
 
+  if (ov.is_double_type ())
+    {
+      if (ov.is_real_scalar ())
+        return get_scalar ();
+
+      if (ov.is_matrix_type ())
+        return get_matrix ();
+    }
+
   return get_any ();
 }
 
@@ -1345,7 +1580,7 @@
     if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *> (*iter))
       arguments.push_back (std::make_pair (extract->name (), true));
 
-  convert_llvm to_llvm;
+  convert_llvm to_llvm (*this);
   function = to_llvm.convert (module, arguments, blocks, constants);
 
 #ifdef OCTAVE_JIT_DEBUG
@@ -1686,9 +1921,34 @@
 }
 
 void
-jit_convert::visit_index_expression (tree_index_expression&)
+jit_convert::visit_index_expression (tree_index_expression& exp)
 {
-  fail ();
+  std::string type = exp.type_tags ();
+  if (! (type.size () == 1 && type[0] == '('))
+    fail ("Unsupported index operation");
+
+  std::list<tree_argument_list *> args = exp.arg_lists ();
+  if (args.size () != 1)
+    fail ("Bad number of arguments in tree_index_expression");
+
+  tree_argument_list *arg_list = args.front ();
+  if (arg_list->size () != 1)
+    fail ("Bad number of arguments in arg_list");
+
+  tree_expression *tree_object = exp.expression ();
+  jit_value *object = visit (tree_object);
+
+  tree_expression *arg0 = arg_list->front ();
+  jit_value *index = visit (arg0);
+
+  jit_call *call = create<jit_call> (jit_typeinfo::paren_subsref, object, index);
+  block->append (call);
+
+  jit_block *normal = create<jit_block> (block->name ());
+  block->append (create<jit_check_error> (call, normal, final_block));
+  add_block (normal);
+  block = normal;
+  result = call;
 }
 
 void
@@ -2286,7 +2546,7 @@
                   fail (ss.str ());
                 }
 
-              builder.CreateCall (ol.function, phi->argument_llvm (i));
+              create_call (ol, phi->argument (i));
             }
         }
     }
@@ -2305,17 +2565,8 @@
               const jit_function::overload& ol
                 = jit_typeinfo::cast (phi->type (),
                                       phi->argument_type (i));
-              if (! ol.function)
-                {
-                  std::stringstream ss;
-                  ss << "No cast for phi(" << i << "): ";
-                  phi->print (ss);
-                  fail (ss.str ());
-                }
-
-              llvm::Value *casted;
-              casted = builder.CreateCall (ol.function,
-                                           phi->argument_llvm (i));
+
+              llvm::Value *casted = create_call (ol, phi->argument (i));
               llvm_phi->addIncoming (casted, pred);
             }
         }
@@ -2343,14 +2594,14 @@
 jit_convert::convert_llvm::visit (jit_const_range& cr)
 {
   llvm::StructType *stype = llvm::cast<llvm::StructType>(cr.type_llvm ());
-  llvm::Type *dbl = jit_typeinfo::get_scalar_llvm ();
+  llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm ();
   llvm::Type *idx = jit_typeinfo::get_index_llvm ();
   const jit_range& rng = cr.value ();
 
   llvm::Constant *constants[4];
-  constants[0] = llvm::ConstantFP::get (dbl, rng.base);
-  constants[1] = llvm::ConstantFP::get (dbl, rng.limit);
-  constants[2] = llvm::ConstantFP::get (dbl, rng.inc);
+  constants[0] = llvm::ConstantFP::get (scalar_t, rng.base);
+  constants[1] = llvm::ConstantFP::get (scalar_t, rng.limit);
+  constants[2] = llvm::ConstantFP::get (scalar_t, rng.inc);
   constants[3] = llvm::ConstantInt::get (idx, rng.nelem);
 
   llvm::Value *as_llvm;
@@ -2386,39 +2637,25 @@
 void
 jit_convert::convert_llvm::visit (jit_call& call)
 {
-  const jit_function::overload& ol = call.overload ();
-  if (! ol.function)
-    fail ("No overload for: " + call.print_string ());
-
-  std::vector<llvm::Value *> args (call.argument_count ());
-  for (size_t i = 0; i < call.argument_count (); ++i)
-    args[i] = call.argument_llvm (i);
-
-  call.stash_llvm (builder.CreateCall (ol.function, args));
+  llvm::Value *ret = create_call (call.overload (), call.arguments ());
+  call.stash_llvm (ret);
 }
 
 void
 jit_convert::convert_llvm::visit (jit_extract_argument& extract)
 {
-  const jit_function::overload& ol = extract.overload ();
-  if (! ol.function)
-    fail ();
-
   llvm::Value *arg = arguments[extract.name ()];
   assert (arg);
   arg = builder.CreateLoad (arg);
-  extract.stash_llvm (builder.CreateCall (ol.function, arg, extract.name ()));
+
+  jit_value *jarg = jthis.create<jit_argument> (jit_typeinfo::get_any (), arg);
+  extract.stash_llvm (create_call (extract.overload (), jarg));
 }
 
 void
 jit_convert::convert_llvm::visit (jit_store_argument& store)
 {
-  llvm::Value *arg_value = store.result_llvm ();
-  const jit_function::overload& ol = store.overload ();
-  if (! ol.function)
-    fail ();
-
-  arg_value = builder.CreateCall (ol.function, arg_value);
+  llvm::Value *arg_value = create_call (store.overload (), store.result ());
 
   llvm::Value *arg = arguments[store.name ()];
   store.stash_llvm (builder.CreateStore (arg_value, arg));
@@ -2463,6 +2700,69 @@
 jit_convert::convert_llvm::visit (jit_assign&)
 {}
 
+void
+jit_convert::convert_llvm::visit (jit_argument&)
+{}
+
+llvm::Value *
+jit_convert::convert_llvm::create_call (const jit_function::overload& ol,
+                                        const std::vector<jit_value *>& jargs)
+{
+   llvm::Function *fun = ol.function;
+   if (! fun)
+     fail ("Missing overload");
+
+  const llvm::Function::ArgumentListType& alist = fun->getArgumentList ();
+  size_t nargs = alist.size ();
+  bool sret = false;
+  if (nargs != jargs.size ())
+    {
+      // first argument is the structure return value
+      assert (nargs == jargs.size () + 1);
+      sret = true;
+    }
+
+  std::vector<llvm::Value *> args (nargs);
+  llvm::Function::arg_iterator llvm_arg = fun->arg_begin ();
+  if (sret)
+    {
+      args[0] = builder.CreateAlloca (ol.result->to_llvm ());
+      ++llvm_arg;
+    }
+
+  for (size_t i = 0; i < jargs.size (); ++i, ++llvm_arg)
+    {
+      llvm::Value *arg = jargs[i]->to_llvm ();
+      llvm::Type *arg_type = arg->getType ();
+      llvm::Type *llvm_arg_type = llvm_arg->getType ();
+
+      if (arg_type == llvm_arg_type)
+        args[i + sret] = arg;
+      else
+        {
+          // pass structure by pointer
+          assert (arg_type->getPointerTo () == llvm_arg_type);
+          llvm::Value *new_arg = builder.CreateAlloca (arg_type);
+          builder.CreateStore (arg, new_arg);
+          args[i + sret] = new_arg;
+        }
+    }
+
+  llvm::Value *llvm_call = builder.CreateCall (fun, args);
+  return sret ? builder.CreateLoad (args[0]) : llvm_call;
+}
+
+llvm::Value *
+jit_convert::convert_llvm::create_call (const jit_function::overload& ol,
+                                        const std::vector<jit_use>& uses)
+{
+  std::vector<jit_value *> values (uses.size ());
+  for (size_t i = 0; i < uses.size (); ++i)
+    values[i] = uses[i].value ();
+
+  return create_call (ol, values);
+}
+
 // -------------------- tree_jit --------------------
 
 tree_jit::tree_jit (void) : module (0), engine (0)
--- a/src/pt-jit.h	Mon Jun 11 20:11:20 2012 -0500
+++ b/src/pt-jit.h	Thu Jun 14 16:38:06 2012 -0500
@@ -219,6 +219,43 @@
 
 std::ostream& operator<< (std::ostream& os, const jit_range& rng);
 
+// jit_array is compatible with the llvm array/matrix structures
+template <typename T, typename U>
+struct
+jit_array
+{
+  jit_array (T& from) : ref_count (from.jit_ref_count ()),
+                        slice_data (from.jit_slice_data () - 1),
+                        slice_len (from.capacity ()),
+                        dimensions (from.jit_dimensions ()),
+                        array_rep (from.jit_array_rep ())
+  {
+    grab_dimensions ();
+  }
+
+  void grab_dimensions (void)
+  {
+    ++(dimensions[-2]);
+  }
+
+  operator T () const
+  {
+    return T (slice_data + 1, slice_len, dimensions, array_rep);
+  }
+
+  int *ref_count;
+
+  U *slice_data;
+  octave_idx_type slice_len;
+  octave_idx_type *dimensions;
+
+  void *array_rep;
+};
+
+typedef jit_array<NDArray, double> jit_matrix;
+
+std::ostream& operator<< (std::ostream& os, const jit_matrix& mat);
+
 // Used to keep track of estimated (infered) types during JIT. This is a
 // hierarchical type system which includes both concrete and abstract types.
 //
@@ -384,6 +421,8 @@
 
   static jit_type *get_any (void) { return instance->any; }
 
+  static jit_type *get_matrix (void) { return instance->matrix; }
+
   static jit_type *get_scalar (void) { return instance->scalar; }
 
   static llvm::Type *get_scalar_llvm (void) { return instance->scalar->to_llvm (); }
@@ -445,6 +484,11 @@
     return instance->make_range_fn;
   }
 
+  static const jit_function& paren_subsref (void)
+  {
+    return instance->paren_subsref_fn;
+  }
+
   static const jit_function& logically_true (void)
   {
     return instance->logically_true_fn;
@@ -597,6 +641,18 @@
   }
 
   llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
+                                   llvm::Type *arg0, llvm::Type *arg1,
+                                   llvm::Type *arg2, llvm::Type *arg3)
+  {
+    std::vector<llvm::Type *> args (4);
+    args[0] = arg0;
+    args[1] = arg1;
+    args[2] = arg2;
+    args[3] = arg3;
+    return create_function (name, ret, args);
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
                                    const std::vector<llvm::Type *>& args);
 
   llvm::Function *create_identity (jit_type *type);
@@ -609,11 +665,11 @@
   llvm::ExecutionEngine *engine;
   int next_id;
 
-  llvm::Type *ov_t;
   llvm::GlobalVariable *lerror_state;
 
   std::vector<jit_type*> id_to_type;
   jit_type *any;
+  jit_type *matrix;
   jit_type *scalar;
   jit_type *range;
   jit_type *string;
@@ -629,6 +685,7 @@
   jit_function for_index_fn;
   jit_function logically_true_fn;
   jit_function make_range_fn;
+  jit_function paren_subsref_fn;
 
   // type id -> cast function TO that type
   std::vector<jit_function> casts;
@@ -651,7 +708,8 @@
   JIT_METH(phi);                                \
   JIT_METH(variable);                           \
   JIT_METH(check_error);                        \
-  JIT_METH(assign)
+  JIT_METH(assign)                              \
+  JIT_METH(argument)
 
 #define JIT_VISIT_IR_CONST                      \
   JIT_METH(const_scalar);                       \
@@ -830,18 +888,18 @@
     : already_infered (nargs, reinterpret_cast<jit_type *>(0)),
       mid (next_id ()), mparent (0)
   {
-    arguments.reserve (nargs);
+    marguments.reserve (nargs);
   }
 
   jit_instruction (jit_value *arg0)
-    : already_infered (1, reinterpret_cast<jit_type *>(0)), arguments (1), 
+    : already_infered (1, reinterpret_cast<jit_type *>(0)), marguments (1), 
       mid (next_id ()), mparent (0)
   {
     stash_argument (0, arg0);
   }
 
   jit_instruction (jit_value *arg0, jit_value *arg1)
-    : already_infered (2, reinterpret_cast<jit_type *>(0)), arguments (2), 
+    : already_infered (2, reinterpret_cast<jit_type *>(0)), marguments (2), 
       mid (next_id ()), mparent (0)
   {
     stash_argument (0, arg0);
@@ -849,7 +907,7 @@
   }
 
   jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2)
-    : already_infered (3, reinterpret_cast<jit_type *>(0)), arguments (3), 
+    : already_infered (3, reinterpret_cast<jit_type *>(0)), marguments (3), 
       mid (next_id ()), mparent (0)
   {
     stash_argument (0, arg0);
@@ -859,7 +917,7 @@
 
   jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2,
                    jit_value *arg3)
-    : already_infered (3, reinterpret_cast<jit_type *>(0)), arguments (4), 
+    : already_infered (3, reinterpret_cast<jit_type *>(0)), marguments (4), 
       mid (next_id ()), mparent (0)
   {
     stash_argument (0, arg0);
@@ -875,7 +933,7 @@
 
   jit_value *argument (size_t i) const
   {
-    return arguments[i].value ();
+    return marguments[i].value ();
   }
 
   llvm::Value *argument_llvm (size_t i) const
@@ -905,25 +963,25 @@
 
   void stash_argument (size_t i, jit_value *arg)
   {
-    arguments[i].stash_value (arg, this, i);
+    marguments[i].stash_value (arg, this, i);
   }
 
   void push_argument (jit_value *arg)
   {
-    arguments.push_back (jit_use ());
-    stash_argument (arguments.size () - 1, arg);
+    marguments.push_back (jit_use ());
+    stash_argument (marguments.size () - 1, arg);
     already_infered.push_back (0);
   }
 
   size_t argument_count (void) const
   {
-    return arguments.size ();
+    return marguments.size ();
   }
 
   void resize_arguments (size_t acount, jit_value *adefault = 0)
   {
-    size_t old = arguments.size ();
-    arguments.resize (acount);
+    size_t old = marguments.size ();
+    marguments.resize (acount);
     already_infered.resize (acount);
 
     if (adefault)
@@ -931,6 +989,8 @@
         stash_argument (i, adefault);
   }
 
+  const std::vector<jit_use>& arguments (void) const { return marguments; }
+
   // argument types which have been infered already
   const std::vector<jit_type *>& argument_types (void) const
   { return already_infered; }
@@ -974,7 +1034,7 @@
     return ret++;
   }
 
-  std::vector<jit_use> arguments;
+  std::vector<jit_use> marguments;
 
   size_t mid;
   jit_block *mparent;
@@ -982,9 +1042,29 @@
 };
 
 // defnie accept methods for subclasses
-#define JIT_VALUE_ACCEPT(clname)                \
+#define JIT_VALUE_ACCEPT                        \
   virtual void accept (jit_ir_walker& walker);
 
+// for use as a dummy argument during conversion to LLVM
+class
+jit_argument : public jit_value
+{
+public:
+  jit_argument (jit_type *atype, llvm::Value *avalue)
+  {
+    stash_type (atype);
+    stash_llvm (avalue);
+  }
+
+  virtual std::ostream& print (std::ostream& os, size_t indent = 0) const
+  {
+    print_indent (os, indent);
+    return jit_print (os, type ()) << ": DUMMY";
+  }
+
+  JIT_VALUE_ACCEPT;
+};
+
 template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T,
           bool QUOTE>
 class
@@ -1012,7 +1092,7 @@
     return os;
   }
 
-  JIT_VALUE_ACCEPT (jit_const);
+  JIT_VALUE_ACCEPT;
 private:
   T mvalue;
 };
@@ -1212,7 +1292,7 @@
   void stash_location (std::list<jit_block *>::iterator alocation)
   { mlocation = alocation; }
 
-  JIT_VALUE_ACCEPT (block);
+  JIT_VALUE_ACCEPT;
 private:
   void internal_append (jit_instruction *instr);
 
@@ -1370,7 +1450,7 @@
     return print_indent (os, indent) << mname;
   }
 
-  JIT_VALUE_ACCEPT (variable)
+  JIT_VALUE_ACCEPT;
 private:
   std::string mname;
   std::stack<jit_value *> value_stack;
@@ -1426,7 +1506,7 @@
     return print_indent (os, indent) << *dest () << " = " << *src ();
   }
 
-  JIT_VALUE_ACCEPT (assign);
+  JIT_VALUE_ACCEPT;
 private:
   jit_variable *mdest;
 };
@@ -1498,7 +1578,7 @@
     return os << "#" << id ();
   }
 
-  JIT_VALUE_ACCEPT (phi);
+  JIT_VALUE_ACCEPT;
 private:
   std::vector<jit_phi_incomming> mincomming;
 };
@@ -1597,7 +1677,7 @@
     return print_successor (os);
   }
 
-  JIT_VALUE_ACCEPT (break)
+  JIT_VALUE_ACCEPT;
 };
 
 class
@@ -1629,7 +1709,7 @@
     return print_successor (os, 1);
   }
 
-  JIT_VALUE_ACCEPT (cond_break)
+  JIT_VALUE_ACCEPT;
 };
 
 class
@@ -1691,7 +1771,7 @@
 
   virtual bool infer (void);
 
-  JIT_VALUE_ACCEPT (call)
+  JIT_VALUE_ACCEPT;
 private:
   const jit_function& mfunction;
 };
@@ -1718,7 +1798,7 @@
     return print_successor (os, 0);
   }
 
-  JIT_VALUE_ACCEPT (jit_check_error)
+  JIT_VALUE_ACCEPT;
 protected:
   virtual bool check_alive (size_t idx) const
   {
@@ -1753,7 +1833,7 @@
     return short_print (os) << " = extract " << name ();
   }
 
-  JIT_VALUE_ACCEPT (extract_argument)
+  JIT_VALUE_ACCEPT;
 };
 
 class
@@ -1804,7 +1884,7 @@
     return os;
   }
 
-  JIT_VALUE_ACCEPT (store_argument)
+  JIT_VALUE_ACCEPT;
 private:
   jit_variable *dest;
 };
@@ -2103,6 +2183,8 @@
   convert_llvm : public jit_ir_walker
   {
   public:
+    convert_llvm (jit_convert& jc) : jthis (jc) {}
+
     llvm::Function *convert (llvm::Module *module,
                              const std::vector<std::pair<std::string, bool> >& args,
                              const std::list<jit_block *>& blocks,
@@ -2129,7 +2211,30 @@
     {
       jvalue.accept (*this);
     }
+
+    llvm::Value *create_call (const jit_function::overload& ol, jit_value *arg0)
+    {
+      std::vector<jit_value *> args (1, arg0);
+      return create_call (ol, args);
+    }
+
+    llvm::Value *create_call (const jit_function::overload& ol, jit_value *arg0,
+                              jit_value *arg1)
+    {
+      std::vector<jit_value *> args (2);
+      args[0] = arg0;
+      args[1] = arg1;
+
+      return create_call (ol, args);
+    }
+
+    llvm::Value *create_call (const jit_function::overload& ol,
+                              const std::vector<jit_value *>& jargs);
+
+    llvm::Value *create_call (const jit_function::overload& ol,
+                              const std::vector<jit_use>& uses);
   private:
+    jit_convert &jthis;
     llvm::Function *function;
   };
 };