# HG changeset patch # User Max Brister # Date 1340652105 18000 # Node ID bbeef7b8ea2ef49e1d5e16cff33874e0cfeef677 # Parent 7f60cdfcc0e50f463a665b341a4b3a3cb20d4850 Add support for matrix indexed assignment to JIT * src/pt-jit.cc (octave_jit_paren_subsasgn_impl, jit_convert::resolve): New functions. (jit_typeinfo::jit_typeinfo): Add subsasgn implementation in llvm. (jit_convert::visit_simple_for_command): Use new do_assign overload. (jit_convert::visit_index_expression): Use new do_assign overload and resolve. (jit_convert::visit_simple_assignment): Use new do_assign overload. (jit_convert::do_assign): New overload. (jit_convert::convert_llvm::visit): Check if assignment is artificial. * src/pt-jit.h (jit_typeinfo::paren_subsasgn, jit_convert::create_check): New functions. (jit_assign::jit_assign): Initialize martificial. (jit_assign::artificial, jit_assign::mark_artificial): New functions. (jit_assign::print): Print the artificial flag. (jit_convert::create_checked_impl): Call create_check. (jit_convert::resolve): New declaration. (jit_convert::do_assign): New overload declaration. 
diff -r 7f60cdfcc0e5 -r bbeef7b8ea2e src/pt-jit.cc --- a/src/pt-jit.cc Fri Jun 22 17:17:48 2012 -0500 +++ b/src/pt-jit.cc Mon Jun 25 14:21:45 2012 -0500 @@ -237,6 +237,24 @@ } extern "C" void +octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, + double value) +{ + std::cout << "impl\n"; + NDArray *array = mat->array; + if (array->nelem () < index) + array->resize1 (index); + + double *data = array->fortran_vec (); + data[index - 1] = value; + + mat->ref_count = array->jit_ref_count (); + mat->slice_data = array->jit_slice_data () - 1; + mat->dimensions = array->jit_dimensions (); + mat->slice_len = array->nelem (); +} + +extern "C" void octave_jit_print_matrix (jit_matrix *m) { std::cout << *m << std::endl; @@ -755,6 +773,92 @@ llvm::verifyFunction (*fn); paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar); + // paren subsasgn + paren_subsasgn_fn.stash_name ("()subsasgn"); + + llvm::Function *resize_paren_subsasgn + = create_function ("octave_jit_paren_subsasgn_impl", void_t, + matrix_t->getPointerTo (), index_t, scalar_t); + engine->addGlobalMapping (resize_paren_subsasgn, + reinterpret_cast (&octave_jit_paren_subsasgn_impl)); + + fn = create_function ("octave_jit_paren_subsasgn", matrix, matrix, scalar, + scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + + llvm::Function::arg_iterator args = fn->arg_begin (); + llvm::Value *mat = args++; + llvm::Value *idx = args++; + llvm::Value *value = args; + + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); + + llvm::BasicBlock *conv_error = 
llvm::BasicBlock::Create (context, + "conv_error", fn, + done); + llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, + done); + builder.CreateCondBr (cond, conv_error, normal); + builder.SetInsertPoint (conv_error); + builder.CreateCall (ginvalid_index); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef (2)); + cond0 = builder.CreateICmpSGT (int_idx, len); + + llvm::Value *rcount = builder.CreateExtractValue (mat, 0); + rcount = builder.CreateLoad (rcount); + cond1 = builder.CreateICmpSGT (rcount, one); + cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, + "bounds_error", + fn, done); + + llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", + fn, done); + builder.CreateCondBr (cond, bounds_error, success); + + // resize on out of bounds access + builder.SetInsertPoint (bounds_error); + llvm::Value *resize_result = builder.CreateAlloca (matrix_t); + builder.CreateStore (mat, resize_result); + builder.CreateCall3 (resize_paren_subsasgn, resize_result, int_idx, value); + resize_result = builder.CreateLoad (resize_result); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + builder.CreateStore (value, gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3); + builder.Insert (merge); + merge->addIncoming (mat, conv_error); + merge->addIncoming (resize_result, bounds_error); + merge->addIncoming (mat, success); + builder.CreateRet (merge); + } + llvm::verifyFunction (*fn); + paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, scalar, scalar); + + // paren_subsasgn + casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name 
("(scalar)"); @@ -1689,12 +1793,6 @@ prot.protect_var (breaking); breaks.clear (); - // FIXME: one of these days we will introduce proper lvalues... - tree_identifier *lhs = dynamic_cast(cmd.left_hand_side ()); - if (! lhs) - fail (); - std::string lhs_name = lhs->name (); - // we need a variable for our iterator, because it is used in multiple blocks std::stringstream ss; ss << "#iter" << iterator_count++; @@ -1719,9 +1817,10 @@ block = body; // compute the syntactical iterator - jit_call *idx_rhs = create (jit_typeinfo::for_index, control, iterator); + jit_call *idx_rhs = create (jit_typeinfo::for_index, control, + iterator); block->append (idx_rhs); - do_assign (lhs_name, idx_rhs, false); + do_assign (cmd.left_hand_side (), idx_rhs); // do loop tree_statement_list *pt_body = cmd.body (); @@ -1901,26 +2000,9 @@ void jit_convert::visit_index_expression (tree_index_expression& exp) { - std::string type = exp.type_tags (); - if (! (type.size () == 1 && type[0] == '(')) - fail ("Unsupported index operation"); - - std::list args = exp.arg_lists (); - if (args.size () != 1) - fail ("Bad number of arguments in tree_index_expression"); - - tree_argument_list *arg_list = args.front (); - if (! arg_list) - fail ("null argument list"); - - if (arg_list->size () != 1) - fail ("Bad number of arguments in arg_list"); - - tree_expression *tree_object = exp.expression (); - jit_value *object = visit (tree_object); - - tree_expression *arg0 = arg_list->front (); - jit_value *index = visit (arg0); + std::pair res = resolve (exp); + jit_value *object = res.first; + jit_value *index = res.second; result = create_checked (jit_typeinfo::paren_subsref, object, index); } @@ -2013,13 +2095,7 @@ tree_expression *rhs = tsa.right_hand_side (); jit_value *rhsv = visit (rhs); - // resolve lhs - tree_expression *lhs = tsa.left_hand_side (); - if (! 
lhs->is_identifier ()) - fail (); - - std::string lhs_name = lhs->name (); - result = do_assign (lhs_name, rhsv, tsa.print_result ()); + do_assign (tsa.left_hand_side (), rhsv); } void @@ -2156,12 +2232,68 @@ return vmap[vname] = var; } +std::pair +jit_convert::resolve (tree_index_expression& exp) +{ + std::string type = exp.type_tags (); + if (! (type.size () == 1 && type[0] == '(')) + fail ("Unsupported index operation"); + + std::list args = exp.arg_lists (); + if (args.size () != 1) + fail ("Bad number of arguments in tree_index_expression"); + + tree_argument_list *arg_list = args.front (); + if (! arg_list) + fail ("null argument list"); + + if (arg_list->size () != 1) + fail ("Bad number of arguments in arg_list"); + + tree_expression *tree_object = exp.expression (); + jit_value *object = visit (tree_object); + tree_expression *arg0 = arg_list->front (); + jit_value *index = visit (arg0); + + return std::make_pair (object, index); +} + +jit_value * +jit_convert::do_assign (tree_expression *exp, jit_value *rhs, bool artificial) +{ + if (! 
exp) + fail ("NULL lhs in assign"); + + if (isa (exp)) + return do_assign (exp->name (), rhs, exp->print_result (), artificial); + else if (tree_index_expression *idx + = dynamic_cast (exp)) + { + std::pair res = resolve (*idx); + jit_value *object = res.first; + jit_value *index = res.second; + jit_call *new_object = create (&jit_typeinfo::paren_subsasgn, + object, index, rhs); + block->append (new_object); + do_assign (idx->expression (), new_object, true); + create_check (new_object); + + // FIXME: Will not work for values that must be released/grabbed + return rhs; + } + else + fail ("Unsupported assignment"); +} + jit_value * jit_convert::do_assign (const std::string& lhs, jit_value *rhs, - bool print) + bool print, bool artificial) { jit_variable *var = get_variable (lhs); - block->append (create (var, rhs)); + jit_assign *assign = block->append (create (var, rhs)); + + if (artificial) + assign->mark_artificial (); if (print) { @@ -2776,6 +2908,9 @@ { assign.stash_llvm (assign.src ()->to_llvm ()); + if (assign.artificial ()) + return; + jit_value *new_value = assign.src (); if (isa (new_value)) { diff -r 7f60cdfcc0e5 -r bbeef7b8ea2e src/pt-jit.h --- a/src/pt-jit.h Fri Jun 22 17:17:48 2012 -0500 +++ b/src/pt-jit.h Mon Jun 25 14:21:45 2012 -0500 @@ -39,34 +39,25 @@ // -------------------- Current status -------------------- // Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. -// However, there is no warning emitted on divide by 0. For example, // a = 5; // b = a * 5 + a; // -// For other types all binary operations are compiled but not optimized. For -// example, -// a = [1 2 3] -// b = a + a; -// will compile to do_binary_op (a, a). +// Indexing matrices with scalars works. // -// For loops are compiled again! -// if, elseif, and else statements compile again! -// break and continue now work! -// -// NOTE: Matrix access is currently broken! +// if, elseif, else, break, continue, and for compile. 
Compilation is triggered +// at the start of a simple for loop. // // The octave low level IR is a linear IR, it works by converting everything to // calls to jit_functions. This turns expressions like c = a + b into // c = call binary+ (a, b) -// The jit_functions contain information about overloads for differnt types. For -// example, if we know a and b are scalars, then c must also be a scalar. +// The jit_functions contain information about overloads for different types. +// For example, if we know a and b are scalars, then c must also be a scalar. // // // TODO: -// 1. Support some simple matrix case (and cleanup Octave low level IR) -// 2. Function calls -// 3. Cleanup/documentation -// 4. ... +// 1. Function calls +// 2. Cleanup/documentation +// 3. ... // --------------------------------------------------------- @@ -93,6 +84,7 @@ class octave_base_value; class octave_value; class tree; +class tree_expression; template class jit_internal_node; @@ -498,6 +490,11 @@ return instance->paren_subsref_fn; } + static const jit_function& paren_subsasgn (void) + { + return instance->paren_subsasgn_fn; + } + static const jit_function& logically_true (void) { return instance->logically_true_fn; @@ -695,6 +692,7 @@ jit_function logically_true_fn; jit_function make_range_fn; jit_function paren_subsref_fn; + jit_function paren_subsasgn_fn; // type id -> cast function TO that type std::vector casts; @@ -1557,7 +1555,7 @@ { public: jit_assign (jit_variable *adest, jit_value *asrc) - : jit_assign_base (adest, adest, asrc) {} + : jit_assign_base (adest, adest, asrc), martificial (false) {} jit_value *overwrite (void) const { @@ -1569,6 +1567,13 @@ return argument (1); } + // variables don't get modified in an SSA, but COW requires we modify + // variables. An artificial assign is for when a variable gets modified. We + // need an assign in the SSA, but the reference counts shouldn't be updated. 
+ bool artificial (void) const { return martificial; } + + void mark_artificial (void) { martificial = true; } + virtual bool infer (void) { jit_type *stype = src ()->type (); @@ -1583,10 +1588,17 @@ virtual std::ostream& print (std::ostream& os, size_t indent = 0) const { - return print_indent (os, indent) << *this << " = " << *src (); + print_indent (os, indent) << *this << " = " << *src (); + + if (artificial ()) + os << " [artificial]"; + + return os; } JIT_VALUE_ACCEPT; +private: + bool martificial; }; class @@ -2150,6 +2162,14 @@ return create_checked_impl (ret); } + template + jit_call *create_checked (const ARG0& arg0, const ARG1& arg1, + const ARG2& arg2, const ARG3& arg3) + { + jit_call *ret = create (arg0, arg1, arg2, arg3); + return create_checked_impl (ret); + } + typedef std::list block_list; typedef block_list::iterator block_iterator; @@ -2199,9 +2219,15 @@ jit_call *create_checked_impl (jit_call *ret) { block->append (ret); - + create_check (ret); + return ret; + } + + jit_error_check *create_check (jit_call *call) + { jit_block *normal = create (block->name ()); - block->append (create (ret, normal, final_block)); + jit_error_check *ret + = block->append (create (call, normal, final_block)); append (normal); block = normal; @@ -2210,8 +2236,13 @@ jit_variable *get_variable (const std::string& vname); - jit_value *do_assign (const std::string& lhs, jit_value *rhs, bool print); - + std::pair resolve (tree_index_expression& exp); + + jit_value *do_assign (tree_expression *exp, jit_value *rhs, + bool artificial = false); + + jit_value *do_assign (const std::string& lhs, jit_value *rhs, bool print, + bool artificial = false); jit_value *visit (tree *tee) { return visit (*tee); }