changeset 14969:bbeef7b8ea2e

Add support for matrix indexed assignment to JIT * src/pt-jit.cc (octave_jit_subsasgn_impl, jit_convert::resolve): New function. (jit_typeinfo::jit_typeinfo): Add subsasgn implementation in llvm. (jit_convert::visit_simple_for_command): Use new do_assign overload. (jit_convert::visit_index_expression): Use new do_assign overload and resolve. (jit_convert::visit_simple_assignment): Use new do_assign overload. (jit_convert::do_assign): New overload. (jit_convert::convert_llvm::visit): Check if assignment is artificial. * src/pt-jit.h (jit_typeinfo::paren_subsasgn, jit_convert::create_check): New function. (jit_assign::jit_assign): Initialize martificial. (jit_assign::artificial, jit_assign::mark_artificial): New function. (jit_assign::print): Print the artificial flag. (jit_convert::create_checked_impl): Call create_check. (jit_convert::resolve): New declaration. (jit_convert::do_assign): New overload declaration.
author Max Brister <max@2bass.com>
date Mon, 25 Jun 2012 14:21:45 -0500
parents 7f60cdfcc0e5
children 0f156affb036
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 226 insertions(+), 60 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc	Fri Jun 22 17:17:48 2012 -0500
+++ b/src/pt-jit.cc	Mon Jun 25 14:21:45 2012 -0500
@@ -237,6 +237,24 @@
 }
 
 extern "C" void
+octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index,
+                                double value)
+{
+  std::cout << "impl\n";
+  NDArray *array = mat->array;
+  if (array->nelem () < index)
+    array->resize1 (index);
+
+  double *data = array->fortran_vec ();
+  data[index - 1] = value;
+
+  mat->ref_count = array->jit_ref_count ();
+  mat->slice_data = array->jit_slice_data () - 1;
+  mat->dimensions = array->jit_dimensions ();
+  mat->slice_len = array->nelem ();
+}
+
+extern "C" void
 octave_jit_print_matrix (jit_matrix *m)
 {
   std::cout << *m << std::endl;
@@ -755,6 +773,92 @@
   llvm::verifyFunction (*fn);
   paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar);
 
+  // paren subsasgn
+  paren_subsasgn_fn.stash_name ("()subsasgn");
+
+  llvm::Function *resize_paren_subsasgn
+    = create_function ("octave_jit_paren_subsasgn_impl", void_t,
+                       matrix_t->getPointerTo (), index_t, scalar_t);
+  engine->addGlobalMapping (resize_paren_subsasgn,
+                            reinterpret_cast<void *> (&octave_jit_paren_subsasgn_impl));
+
+  fn = create_function ("octave_jit_paren_subsasgn", matrix, matrix, scalar,
+                        scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *one = llvm::ConstantInt::get (index_t, 1);
+
+    llvm::Function::arg_iterator args = fn->arg_begin ();
+    llvm::Value *mat = args++;
+    llvm::Value *idx = args++;
+    llvm::Value *value = args;
+
+    llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t);
+    llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t);
+    llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx);
+    llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one);
+    llvm::Value *cond = builder.CreateOr (cond0, cond1);
+
+    llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn);
+
+    llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context,
+                                                             "conv_error", fn,
+                                                             done);
+    llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn,
+                                                         done);
+    builder.CreateCondBr (cond, conv_error, normal);
+    builder.SetInsertPoint (conv_error);
+    builder.CreateCall (ginvalid_index);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (normal);
+    llvm::Value *len = builder.CreateExtractValue (mat,
+                                                   llvm::ArrayRef<unsigned> (2));
+    cond0 = builder.CreateICmpSGT (int_idx, len);
+
+    llvm::Value *rcount = builder.CreateExtractValue (mat, 0);
+    rcount = builder.CreateLoad (rcount);
+    cond1 = builder.CreateICmpSGT (rcount, one);
+    cond = builder.CreateOr (cond0, cond1);
+
+    llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context,
+                                                               "bounds_error",
+                                                               fn, done);
+
+    llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success",
+                                                          fn, done);
+    builder.CreateCondBr (cond, bounds_error, success);
+
+    // resize on out of bounds access
+    builder.SetInsertPoint (bounds_error);
+    llvm::Value *resize_result = builder.CreateAlloca (matrix_t);
+    builder.CreateStore (mat, resize_result);
+    builder.CreateCall3 (resize_paren_subsasgn, resize_result, int_idx, value);
+    resize_result = builder.CreateLoad (resize_result);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (success);
+    llvm::Value *data = builder.CreateExtractValue (mat,
+                                                    llvm::ArrayRef<unsigned> (1));
+    llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx);
+    builder.CreateStore (value, gep);
+    builder.CreateBr (done);
+
+    builder.SetInsertPoint (done);
+
+    llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3);
+    builder.Insert (merge);
+    merge->addIncoming (mat, conv_error);
+    merge->addIncoming (resize_result, bounds_error);
+    merge->addIncoming (mat, success);
+    builder.CreateRet (merge);
+  }
+  llvm::verifyFunction (*fn);
+  paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, scalar, scalar);
+
+  // paren_subsasgn
+
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
 
@@ -1689,12 +1793,6 @@
   prot.protect_var (breaking);
   breaks.clear ();
 
-  // FIXME: one of these days we will introduce proper lvalues...
-  tree_identifier *lhs = dynamic_cast<tree_identifier *>(cmd.left_hand_side ());
-  if (! lhs)
-    fail ();
-  std::string lhs_name = lhs->name ();
-
   // we need a variable for our iterator, because it is used in multiple blocks
   std::stringstream ss;
   ss << "#iter" << iterator_count++;
@@ -1719,9 +1817,10 @@
   block = body;
 
   // compute the syntactical iterator
-  jit_call *idx_rhs = create<jit_call> (jit_typeinfo::for_index, control, iterator);
+  jit_call *idx_rhs = create<jit_call> (jit_typeinfo::for_index, control,
+                                        iterator);
   block->append (idx_rhs);
-  do_assign (lhs_name, idx_rhs, false);
+  do_assign (cmd.left_hand_side (), idx_rhs);
 
   // do loop
   tree_statement_list *pt_body = cmd.body ();
@@ -1901,26 +2000,9 @@
 void
 jit_convert::visit_index_expression (tree_index_expression& exp)
 {
-  std::string type = exp.type_tags ();
-  if (! (type.size () == 1 && type[0] == '('))
-    fail ("Unsupported index operation");
-
-  std::list<tree_argument_list *> args = exp.arg_lists ();
-  if (args.size () != 1)
-    fail ("Bad number of arguments in tree_index_expression");
-
-  tree_argument_list *arg_list = args.front ();
-  if (! arg_list)
-    fail ("null argument list");
-
-  if (arg_list->size () != 1)
-    fail ("Bad number of arguments in arg_list");
-
-  tree_expression *tree_object = exp.expression ();
-  jit_value *object = visit (tree_object);
-
-  tree_expression *arg0 = arg_list->front ();
-  jit_value *index = visit (arg0);
+  std::pair<jit_value *, jit_value *> res = resolve (exp);
+  jit_value *object = res.first;
+  jit_value *index = res.second;
 
   result = create_checked (jit_typeinfo::paren_subsref, object, index);
 }
@@ -2013,13 +2095,7 @@
   tree_expression *rhs = tsa.right_hand_side ();
   jit_value *rhsv = visit (rhs);
 
-  // resolve lhs
-  tree_expression *lhs = tsa.left_hand_side ();
-  if (! lhs->is_identifier ())
-    fail ();
-
-  std::string lhs_name = lhs->name ();
-  result = do_assign (lhs_name, rhsv, tsa.print_result ());
+  do_assign (tsa.left_hand_side (), rhsv);
 }
 
 void
@@ -2156,12 +2232,68 @@
   return vmap[vname] = var;
 }
 
+std::pair<jit_value *, jit_value *>
+jit_convert::resolve (tree_index_expression& exp)
+{
+  std::string type = exp.type_tags ();
+  if (! (type.size () == 1 && type[0] == '('))
+    fail ("Unsupported index operation");
+
+  std::list<tree_argument_list *> args = exp.arg_lists ();
+  if (args.size () != 1)
+    fail ("Bad number of arguments in tree_index_expression");
+
+  tree_argument_list *arg_list = args.front ();
+  if (! arg_list)
+    fail ("null argument list");
+
+  if (arg_list->size () != 1)
+    fail ("Bad number of arguments in arg_list");
+
+  tree_expression *tree_object = exp.expression ();
+  jit_value *object = visit (tree_object);
+  tree_expression *arg0 = arg_list->front ();
+  jit_value *index = visit (arg0);
+
+  return std::make_pair (object, index);
+}
+
+jit_value *
+jit_convert::do_assign (tree_expression *exp, jit_value *rhs, bool artificial)
+{
+  if (! exp)
+    fail ("NULL lhs in assign");
+
+  if (isa<tree_identifier> (exp))
+    return do_assign (exp->name (), rhs, exp->print_result (), artificial);
+  else if (tree_index_expression *idx
+           = dynamic_cast<tree_index_expression *> (exp))
+    {
+      std::pair<jit_value *, jit_value *> res = resolve (*idx);
+      jit_value *object = res.first;
+      jit_value *index = res.second;
+      jit_call *new_object = create<jit_call> (&jit_typeinfo::paren_subsasgn,
+                                               object, index, rhs);
+      block->append (new_object);
+      do_assign (idx->expression (), new_object, true);
+      create_check (new_object);
+
+      // FIXME: Will not work for values that must be release/grabed
+      return rhs;
+    }
+  else
+    fail ("Unsupported assignment");
+}
+
 jit_value *
 jit_convert::do_assign (const std::string& lhs, jit_value *rhs,
-                        bool print)
+                        bool print, bool artificial)
 {
   jit_variable *var = get_variable (lhs);
-  block->append (create<jit_assign> (var, rhs));
+  jit_assign *assign = block->append (create<jit_assign> (var, rhs));
+
+  if (artificial)
+    assign->mark_artificial ();
 
   if (print)
     {
@@ -2776,6 +2908,9 @@
 {
   assign.stash_llvm (assign.src ()->to_llvm ());
 
+  if (assign.artificial ())
+    return;
+
   jit_value *new_value = assign.src ();
   if (isa<jit_assign_base> (new_value))
     {
--- a/src/pt-jit.h	Fri Jun 22 17:17:48 2012 -0500
+++ b/src/pt-jit.h	Mon Jun 25 14:21:45 2012 -0500
@@ -39,34 +39,25 @@
 
 // -------------------- Current status --------------------
 // Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized.
-// However, there is no warning emitted on divide by 0. For example,
 // a = 5;
 // b = a * 5 + a;
 //
-// For other types all binary operations are compiled but not optimized. For
-// example,
-// a = [1 2 3]
-// b = a + a;
-// will compile to do_binary_op (a, a).
+// Indexing matrices with scalars works.
 //
-// For loops are compiled again!
-// if, elseif, and else statements compile again!
-// break and continue now work!
-//
-// NOTE: Matrix access is currently broken!
+// if, elseif, else, break, continue, and for compile. Compilation is triggered
+// at the start of a simple for loop.
 //
 // The octave low level IR is a linear IR, it works by converting everything to
 // calls to jit_functions. This turns expressions like c = a + b into
 // c = call binary+ (a, b)
-// The jit_functions contain information about overloads for differnt types. For
-// example, if we know a and b are scalars, then c must also be a scalar.
+// The jit_functions contain information about overloads for different types.
+// For, example, if we know a and b are scalars, then c must also be a scalar.
 //
 //
 // TODO:
-// 1. Support some simple matrix case (and cleanup Octave low level IR)
-// 2. Function calls
-// 3. Cleanup/documentation
-// 4. ...
+// 1. Function calls
+// 2. Cleanup/documentation
+// 3. ...
 // ---------------------------------------------------------
 
 
@@ -93,6 +84,7 @@
 class octave_base_value;
 class octave_value;
 class tree;
+class tree_expression;
 
 template <typename HOLDER_T, typename SUB_T>
 class jit_internal_node;
@@ -498,6 +490,11 @@
     return instance->paren_subsref_fn;
   }
 
+  static const jit_function& paren_subsasgn (void)
+  {
+    return instance->paren_subsasgn_fn;
+  }
+
   static const jit_function& logically_true (void)
   {
     return instance->logically_true_fn;
@@ -695,6 +692,7 @@
   jit_function logically_true_fn;
   jit_function make_range_fn;
   jit_function paren_subsref_fn;
+  jit_function paren_subsasgn_fn;
 
   // type id -> cast function TO that type
   std::vector<jit_function> casts;
@@ -1557,7 +1555,7 @@
 {
 public:
   jit_assign (jit_variable *adest, jit_value *asrc)
-    : jit_assign_base (adest, adest, asrc) {}
+    : jit_assign_base (adest, adest, asrc), martificial (false) {}
 
   jit_value *overwrite (void) const
   {
@@ -1569,6 +1567,13 @@
     return argument (1);
   }
 
+  // variables don't get modified in an SSA, but COW requires we modify
+  // variables. An artificial assign is for when a variable gets modified. We
+  // need an assign in the SSA, but the reference counts shouldn't be updated.
+  bool artificial (void) const { return martificial; }
+
+  void mark_artificial (void) { martificial = true; }
+
   virtual bool infer (void)
   {
     jit_type *stype = src ()->type ();
@@ -1583,10 +1588,17 @@
 
   virtual std::ostream& print (std::ostream& os, size_t indent = 0) const
   {
-    return print_indent (os, indent) << *this << " = " << *src ();
+    print_indent (os, indent) << *this << " = " << *src ();
+
+    if (artificial ())
+      os << " [artificial]";
+
+    return os;
   }
 
   JIT_VALUE_ACCEPT;
+private:
+  bool martificial;
 };
 
 class
@@ -2150,6 +2162,14 @@
     return create_checked_impl (ret);
   }
 
+  template <typename ARG0, typename ARG1, typename ARG2, typename ARG3>
+  jit_call *create_checked (const ARG0& arg0, const ARG1& arg1,
+                            const ARG2& arg2, const ARG3& arg3)
+  {
+    jit_call *ret = create<jit_call> (arg0, arg1, arg2, arg3);
+    return create_checked_impl (ret);
+  }
+
   typedef std::list<jit_block *> block_list;
   typedef block_list::iterator block_iterator;
 
@@ -2199,9 +2219,15 @@
   jit_call *create_checked_impl (jit_call *ret)
   {
     block->append (ret);
-
+    create_check (ret);
+    return ret;
+  }
+
+  jit_error_check *create_check (jit_call *call)
+  {
     jit_block *normal = create<jit_block> (block->name ());
-    block->append (create<jit_error_check> (ret, normal, final_block));
+    jit_error_check *ret
+      = block->append (create<jit_error_check> (call, normal, final_block));
     append (normal);
     block = normal;
 
@@ -2210,8 +2236,13 @@
 
   jit_variable *get_variable (const std::string& vname);
 
-  jit_value *do_assign (const std::string& lhs, jit_value *rhs, bool print);
-
+  std::pair<jit_value *, jit_value *> resolve (tree_index_expression& exp);
+
+  jit_value *do_assign (tree_expression *exp, jit_value *rhs,
+                        bool artificial = false);
+
+  jit_value *do_assign (const std::string& lhs, jit_value *rhs, bool print,
+                        bool artificial = false);
 
   jit_value *visit (tree *tee) { return visit (*tee); }