changeset 14967:0cfe0cf55a02

Simplify matrix handling in JIT * src/pt-jit.cc (octave_jit_release_matrix, octave_jit_grab_matrix): New function. (octave_jit_delete_matrix): Removed function. (octave_jit_cast_any_matrix, octave_jit_print_matrix): Use new jit_matrix layout. (jit_typeinfo::jit_typeing): Removed identity overload for grab/release and do not release matrix on subsref. * src/pt-jit.h (jit_matrix::jit_matrix): Initialize NDArray field. (jit_matrix::operator T): Return NDArray directly.
author Max Brister <max@2bass.com>
date Fri, 22 Jun 2012 15:46:26 -0500
parents b7b647bc4b90
children 7f60cdfcc0e5
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 28 insertions(+), 70 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc	Fri Jun 22 11:43:30 2012 -0500
+++ b/src/pt-jit.cc	Fri Jun 22 15:46:26 2012 -0500
@@ -147,9 +147,9 @@
 }
 
 extern "C" void
-octave_jit_delete_matrix (jit_matrix *m)
+octave_jit_release_matrix (jit_matrix *m)
 {
-  NDArray array (*m);
+  delete m->array;
 }
 
 extern "C" octave_base_value *
@@ -159,15 +159,19 @@
   return obv;
 }
 
-extern "C" octave_base_value *
-octave_jit_cast_any_matrix (jit_matrix *jmatrix)
+extern "C" void
+octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m)
 {
-  ++(*jmatrix->ref_count);
-  NDArray matrix = *jmatrix;
-  octave_value ret (matrix);
-
+  *result = *m->array;
+}
+
+extern "C" octave_base_value *
+octave_jit_cast_any_matrix (jit_matrix *m)
+{
+  octave_value ret (*m->array);
   octave_base_value *rep = ret.internal_rep ();
   rep->grab ();
+
   return rep;
 }
 
@@ -253,7 +257,7 @@
 {
   return os << "Matrix[" << mat.ref_count << ", " << mat.slice_data << ", "
             << mat.slice_len << ", " << mat.dimensions << ", "
-            << mat.array_rep << "]";
+            << mat.array << "]";
 }
 
 // -------------------- jit_type --------------------
@@ -445,30 +449,12 @@
                                                   matrix_t->getPointerTo ());
   engine->addGlobalMapping (print_matrix,
                             reinterpret_cast<void*>(&octave_jit_print_matrix));
-
-  fn = create_function ("octave_jit_grab_matrix", matrix, matrix);
-  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
-  {
-    llvm::Value *one = llvm::ConstantInt::get (refcount_t, 1);
-
-    llvm::Value *mat = fn->arg_begin ();
-    llvm::Value *rcount= builder.CreateExtractValue (mat, 0);
-    llvm::Value *count = builder.CreateLoad (rcount);
-    count = builder.CreateAdd (count, one);
-    builder.CreateStore (count, rcount);
-    builder.CreateRet (mat);
-  }
+  fn = create_function ("octave_jit_grab_matrix", void_t,
+                        matrix_t->getPointerTo (), matrix_t->getPointerTo ());
+  engine->addGlobalMapping (fn,
+                            reinterpret_cast<void *> (&octave_jit_grab_matrix));
   grab_fn.add_overload (fn, false, matrix, matrix);
 
-  // grab scalar
-  fn = create_identity (scalar);
-  grab_fn.add_overload (fn, false, scalar, scalar);
-
-  // grab index
-  fn = create_identity (index);
-  grab_fn.add_overload (fn, false, index, index);
-
   // release any
   fn = create_function ("octave_jit_release_any", void_t, any_t);
   engine->addGlobalMapping (fn,
@@ -477,37 +463,10 @@
   release_fn.stash_name ("release");
 
   // release matrix
-  llvm::Function *delete_mat = create_function ("octave_jit_delete_matrix",
-                                                void_t, matrix_t);
-  engine->addGlobalMapping (delete_mat,
-                            reinterpret_cast<void*> (&octave_jit_delete_matrix));
-
-  fn = create_function ("octave_jit_release_matrix", void_t, matrix_t);
-  llvm::Function *release_mat = fn;
-  body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
-  {
-    llvm::Value *one = llvm::ConstantInt::get (refcount_t, 1);
-    llvm::Value *zero = llvm::ConstantInt::get (refcount_t, 0);
-
-    llvm::Value *mat = fn->arg_begin ();
-    llvm::Value *rcount= builder.CreateExtractValue (mat, 0);
-    llvm::Value *count = builder.CreateLoad (rcount);
-    count = builder.CreateSub (count, one);
-
-    llvm::BasicBlock *dead = llvm::BasicBlock::Create (context, "dead", fn);
-    llvm::BasicBlock *live = llvm::BasicBlock::Create (context, "live", fn);
-    llvm::Value *isdead = builder.CreateICmpEQ (count, zero);
-    builder.CreateCondBr (isdead, dead, live);
-
-    builder.SetInsertPoint (dead);
-    builder.CreateCall (delete_mat, mat);
-    builder.CreateRetVoid ();
-
-    builder.SetInsertPoint (live);
-    builder.CreateStore (count, rcount);
-    builder.CreateRetVoid ();
-  }
+  fn = create_function ("octave_jit_release_matrix", void_t,
+                        matrix_t->getPointerTo ());
+  engine->addGlobalMapping (fn,
+                            reinterpret_cast<void *> (&octave_jit_release_matrix));
   release_fn.add_overload (fn, false, 0, matrix);
 
   // release scalar
@@ -538,7 +497,7 @@
 
   // divide is annoying because it might error
   fn = create_function ("octave_jit_div_scalar_scalar", scalar, scalar, scalar);
-  body = llvm::BasicBlock::Create (context, "body", fn);
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   {
     llvm::BasicBlock *warn_block = llvm::BasicBlock::Create (context, "warn",
@@ -733,7 +692,7 @@
     else
       ione = llvm::ConstantInt::get (int_t, 1);
 
-    llvm::Value *szero = llvm::ConstantFP::get (scalar_t, 0);
+    llvm::Value *undef = llvm::UndefValue::get (scalar_t);
 
     llvm::Function::arg_iterator args = fn->arg_begin ();
     llvm::Value *mat = args++;
@@ -788,10 +747,9 @@
 
     llvm::PHINode *merge = llvm::PHINode::Create (scalar_t, 3);
     builder.Insert (merge);
-    merge->addIncoming (szero, conv_error);
-    merge->addIncoming (szero, bounds_error);
+    merge->addIncoming (undef, conv_error);
+    merge->addIncoming (undef, bounds_error);
     merge->addIncoming (ret, success);
-    builder.CreateCall (release_mat, mat);
     builder.CreateRet (merge);
   }
   llvm::verifyFunction (*fn);
--- a/src/pt-jit.h	Fri Jun 22 11:43:30 2012 -0500
+++ b/src/pt-jit.h	Fri Jun 22 15:46:26 2012 -0500
@@ -230,7 +230,7 @@
                         slice_data (from.jit_slice_data () - 1),
                         slice_len (from.capacity ()),
                         dimensions (from.jit_dimensions ()),
-                        array_rep (from.jit_array_rep ())
+                        array (new T (from))
   {
     grab_dimensions ();
   }
@@ -242,7 +242,7 @@
 
   operator T () const
   {
-    return T (slice_data + 1, slice_len, dimensions, array_rep);
+    return *array;
   }
 
   int *ref_count;
@@ -251,7 +251,7 @@
   octave_idx_type slice_len;
   octave_idx_type *dimensions;
 
-  void *array_rep;
+  T *array;
 };
 
 typedef jit_array<NDArray, double> jit_matrix;