diff src/pt-jit.cc @ 14969:bbeef7b8ea2e

Add support for matrix indexed assignment to JIT * src/pt-jit.cc (octave_jit_subsasgn_impl, jit_convert::resolve): New function. (jit_typeinfo::jit_typeinfo): Add subsasgn implementation in llvm. (jit_convert::visit_simple_for_command): Use new do_assign overload. (jit_convert::visit_index_expression): Use new do_assign overload and resolve. (jit_convert::visit_simple_assignment): Use new do_assign overload. (jit_convert::do_assign): New overload. (jit_convert::convert_llvm::visit): Check if assignment is artificial. * src/pt-jit.h (jit_typeinfo::paren_subsasgn, jit_convert::create_check): New function. (jit_assign::jit_assign): Initialize martificial. (jit_assign::artificial, jit_assign::mark_artificial): New function. (jit_assign::print): Print the artificial flag. (jit_convert::create_checked_impl): Call create_check. (jit_convert::resolve): New declaration. (jit_convert::do_assign): New overload declaration.
author Max Brister <max@2bass.com>
date Mon, 25 Jun 2012 14:21:45 -0500
parents 7f60cdfcc0e5
children b23a98ca0e43
line wrap: on
line diff
--- a/src/pt-jit.cc	Fri Jun 22 17:17:48 2012 -0500
+++ b/src/pt-jit.cc	Mon Jun 25 14:21:45 2012 -0500
@@ -237,6 +237,24 @@
 }
 
 extern "C" void
+octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index,
+                                double value)
+{
+  std::cout << "impl\n";
+  NDArray *array = mat->array;
+  if (array->nelem () < index)
+    array->resize1 (index);
+
+  double *data = array->fortran_vec ();
+  data[index - 1] = value;
+
+  mat->ref_count = array->jit_ref_count ();
+  mat->slice_data = array->jit_slice_data () - 1;
+  mat->dimensions = array->jit_dimensions ();
+  mat->slice_len = array->nelem ();
+}
+
+extern "C" void
 octave_jit_print_matrix (jit_matrix *m)
 {
   std::cout << *m << std::endl;
@@ -755,6 +773,92 @@
   llvm::verifyFunction (*fn);
   paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar);
 
+  // paren subsasgn
+  paren_subsasgn_fn.stash_name ("()subsasgn");
+
+  llvm::Function *resize_paren_subsasgn
+    = create_function ("octave_jit_paren_subsasgn_impl", void_t,
+                       matrix_t->getPointerTo (), index_t, scalar_t);
+  engine->addGlobalMapping (resize_paren_subsasgn,
+                            reinterpret_cast<void *> (&octave_jit_paren_subsasgn_impl));
+
+  fn = create_function ("octave_jit_paren_subsasgn", matrix, matrix, scalar,
+                        scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *one = llvm::ConstantInt::get (index_t, 1);
+
+    llvm::Function::arg_iterator args = fn->arg_begin ();
+    llvm::Value *mat = args++;
+    llvm::Value *idx = args++;
+    llvm::Value *value = args;
+
+    llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t);
+    llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t);
+    llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx);
+    llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one);
+    llvm::Value *cond = builder.CreateOr (cond0, cond1);
+
+    llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn);
+
+    llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context,
+                                                             "conv_error", fn,
+                                                             done);
+    llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn,
+                                                         done);
+    builder.CreateCondBr (cond, conv_error, normal);
+    builder.SetInsertPoint (conv_error);
+    builder.CreateCall (ginvalid_index);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (normal);
+    llvm::Value *len = builder.CreateExtractValue (mat,
+                                                   llvm::ArrayRef<unsigned> (2));
+    cond0 = builder.CreateICmpSGT (int_idx, len);
+
+    llvm::Value *rcount = builder.CreateExtractValue (mat, 0);
+    rcount = builder.CreateLoad (rcount);
+    cond1 = builder.CreateICmpSGT (rcount, one);
+    cond = builder.CreateOr (cond0, cond1);
+
+    llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context,
+                                                               "bounds_error",
+                                                               fn, done);
+
+    llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success",
+                                                          fn, done);
+    builder.CreateCondBr (cond, bounds_error, success);
+
+    // resize on out of bounds access
+    builder.SetInsertPoint (bounds_error);
+    llvm::Value *resize_result = builder.CreateAlloca (matrix_t);
+    builder.CreateStore (mat, resize_result);
+    builder.CreateCall3 (resize_paren_subsasgn, resize_result, int_idx, value);
+    resize_result = builder.CreateLoad (resize_result);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (success);
+    llvm::Value *data = builder.CreateExtractValue (mat,
+                                                    llvm::ArrayRef<unsigned> (1));
+    llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx);
+    builder.CreateStore (value, gep);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (done);
+
+    llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3);
+    builder.Insert (merge);
+    merge->addIncoming (mat, conv_error);
+    merge->addIncoming (resize_result, bounds_error);
+    merge->addIncoming (mat, success);
+    builder.CreateRet (merge);
+  }
+  llvm::verifyFunction (*fn);
+  paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, scalar, scalar);
+
+  // paren_subsasgn
+
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
 
@@ -1689,12 +1793,6 @@
   prot.protect_var (breaking);
   breaks.clear ();
 
-  // FIXME: one of these days we will introduce proper lvalues...
-  tree_identifier *lhs = dynamic_cast<tree_identifier *>(cmd.left_hand_side ());
-  if (! lhs)
-    fail ();
-  std::string lhs_name = lhs->name ();
-
   // we need a variable for our iterator, because it is used in multiple blocks
   std::stringstream ss;
   ss << "#iter" << iterator_count++;
@@ -1719,9 +1817,10 @@
   block = body;
 
   // compute the syntactical iterator
-  jit_call *idx_rhs = create<jit_call> (jit_typeinfo::for_index, control, iterator);
+  jit_call *idx_rhs = create<jit_call> (jit_typeinfo::for_index, control,
+                                        iterator);
   block->append (idx_rhs);
-  do_assign (lhs_name, idx_rhs, false);
+  do_assign (cmd.left_hand_side (), idx_rhs);
 
   // do loop
   tree_statement_list *pt_body = cmd.body ();
@@ -1901,26 +2000,9 @@
 void
 jit_convert::visit_index_expression (tree_index_expression& exp)
 {
-  std::string type = exp.type_tags ();
-  if (! (type.size () == 1 && type[0] == '('))
-    fail ("Unsupported index operation");
-
-  std::list<tree_argument_list *> args = exp.arg_lists ();
-  if (args.size () != 1)
-    fail ("Bad number of arguments in tree_index_expression");
-
-  tree_argument_list *arg_list = args.front ();
-  if (! arg_list)
-    fail ("null argument list");
-
-  if (arg_list->size () != 1)
-    fail ("Bad number of arguments in arg_list");
-
-  tree_expression *tree_object = exp.expression ();
-  jit_value *object = visit (tree_object);
-
-  tree_expression *arg0 = arg_list->front ();
-  jit_value *index = visit (arg0);
+  std::pair<jit_value *, jit_value *> res = resolve (exp);
+  jit_value *object = res.first;
+  jit_value *index = res.second;
 
   result = create_checked (jit_typeinfo::paren_subsref, object, index);
 }
@@ -2013,13 +2095,7 @@
   tree_expression *rhs = tsa.right_hand_side ();
   jit_value *rhsv = visit (rhs);
 
-  // resolve lhs
-  tree_expression *lhs = tsa.left_hand_side ();
-  if (! lhs->is_identifier ())
-    fail ();
-
-  std::string lhs_name = lhs->name ();
-  result = do_assign (lhs_name, rhsv, tsa.print_result ());
+  do_assign (tsa.left_hand_side (), rhsv);
 }
 
 void
@@ -2156,12 +2232,68 @@
   return vmap[vname] = var;
 }
 
+std::pair<jit_value *, jit_value *>
+jit_convert::resolve (tree_index_expression& exp)
+{
+  std::string type = exp.type_tags ();
+  if (! (type.size () == 1 && type[0] == '('))
+    fail ("Unsupported index operation");
+
+  std::list<tree_argument_list *> args = exp.arg_lists ();
+  if (args.size () != 1)
+    fail ("Bad number of arguments in tree_index_expression");
+
+  tree_argument_list *arg_list = args.front ();
+  if (! arg_list)
+    fail ("null argument list");
+
+  if (arg_list->size () != 1)
+    fail ("Bad number of arguments in arg_list");
+
+  tree_expression *tree_object = exp.expression ();
+  jit_value *object = visit (tree_object);
+  tree_expression *arg0 = arg_list->front ();
+  jit_value *index = visit (arg0);
+
+  return std::make_pair (object, index);
+}
+
+jit_value *
+jit_convert::do_assign (tree_expression *exp, jit_value *rhs, bool artificial)
+{
+  if (! exp)
+    fail ("NULL lhs in assign");
+
+  if (isa<tree_identifier> (exp))
+    return do_assign (exp->name (), rhs, exp->print_result (), artificial);
+  else if (tree_index_expression *idx
+           = dynamic_cast<tree_index_expression *> (exp))
+    {
+      std::pair<jit_value *, jit_value *> res = resolve (*idx);
+      jit_value *object = res.first;
+      jit_value *index = res.second;
+      jit_call *new_object = create<jit_call> (&jit_typeinfo::paren_subsasgn,
+                                               object, index, rhs);
+      block->append (new_object);
+      do_assign (idx->expression (), new_object, true);
+      create_check (new_object);
+
+      // FIXME: Will not work for values that must be release/grabed
+      return rhs;
+    }
+  else
+    fail ("Unsupported assignment");
+}
+
 jit_value *
 jit_convert::do_assign (const std::string& lhs, jit_value *rhs,
-                        bool print)
+                        bool print, bool artificial)
 {
   jit_variable *var = get_variable (lhs);
-  block->append (create<jit_assign> (var, rhs));
+  jit_assign *assign = block->append (create<jit_assign> (var, rhs));
+
+  if (artificial)
+    assign->mark_artificial ();
 
   if (print)
     {
@@ -2776,6 +2908,9 @@
 {
   assign.stash_llvm (assign.src ()->to_llvm ());
 
+  if (assign.artificial ())
+    return;
+
   jit_value *new_value = assign.src ();
   if (isa<jit_assign_base> (new_value))
     {