Mercurial > octave-nkf
diff src/pt-jit.cc @ 14917:232d8ab07932
Rewrite pt-jit.* adding new low level octave IR
* src/pt-eval.cc (tree_evaluator::visit_simple_for_command): Remove jit
(tree_evaluator::visit_statement): Add jit
* src/pt-jit.h: Rewrite
* src/pt-jit.cc: Rewrite
author | Max Brister <max@2bass.com> |
---|---|
date | Thu, 24 May 2012 15:08:09 -0600 |
parents | cba58541954c |
children | 13465aab507f |
line wrap: on
line diff
--- a/src/pt-jit.cc Wed May 23 16:22:05 2012 -0400 +++ b/src/pt-jit.cc Thu May 24 15:08:09 2012 -0600 @@ -52,6 +52,7 @@ #include "octave.h" #include "ov-fcn-handle.h" #include "ov-usr-fcn.h" +#include "ov-scalar.h" #include "pt-all.h" // FIXME: Remove eventually @@ -60,6 +61,10 @@ static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); +static llvm::LLVMContext& context = llvm::getGlobalContext (); + +jit_typeinfo *jit_typeinfo::instance; + // thrown when we should give up on JIT and interpret class jit_fail_exception : public std::exception {}; @@ -102,10 +107,25 @@ obv->release (); } -extern "C" void +extern "C" octave_base_value * octave_jit_grab_any (octave_base_value *obv) { obv->grab (); + return obv; +} + +extern "C" double +octave_jit_cast_scalar_any (octave_base_value *obv) +{ + double ret = obv->double_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_scalar (double value) +{ + return new octave_scalar (value); } // -------------------- jit_type -------------------- @@ -155,6 +175,10 @@ if (types.size () >= overloads.size ()) return null_overload; + for (size_t i =0; i < types.size (); ++i) + if (! types[i]) + return null_overload; + const Array<overload>& over = overloads[types.size ()]; dim_vector dv (over.dims ()); Array<octave_idx_type> idx = to_idx (types); @@ -187,46 +211,56 @@ } // -------------------- jit_typeinfo -------------------- +void +jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) +{ + instance = new jit_typeinfo (m, e); +} + jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) : module (m), engine (e), next_id (0) { // FIXME: We should be registering types like in octave_value_typeinfo - llvm::LLVMContext &ctx = m->getContext (); - - ov_t = llvm::StructType::create (ctx, "octave_base_value"); + ov_t = llvm::StructType::create (context, "octave_base_value"); ov_t = ov_t->getPointerTo (); - llvm::Type *dbl = llvm::Type::getDoubleTy (ctx); - llvm::Type *bool_t = llvm::Type::getInt1Ty (ctx); + llvm::Type *dbl = llvm::Type::getDoubleTy (context); + llvm::Type *bool_t = llvm::Type::getInt1Ty (context); + llvm::Type *string_t = llvm::Type::getInt8Ty (context); + string_t = string_t->getPointerTo (); llvm::Type *index_t = 0; switch (sizeof(octave_idx_type)) { case 4: - index_t = llvm::Type::getInt32Ty (ctx); + index_t = llvm::Type::getInt32Ty (context); break; case 8: - index_t = llvm::Type::getInt64Ty (ctx); + index_t = llvm::Type::getInt64Ty (context); break; default: assert (false && "Unrecognized index type size"); } - llvm::StructType *range_t = llvm::StructType::create (ctx, "range"); + llvm::StructType *range_t = llvm::StructType::create (context, "range"); std::vector<llvm::Type *> range_contents (4, dbl); range_contents[3] = index_t; range_t->setBody (range_contents); // create types - any = new_type ("any", true, 0, ov_t); - scalar = new_type ("scalar", false, any, dbl); - range = new_type ("range", false, any, range_t); - boolean = new_type ("bool", false, any, bool_t); - index = new_type ("index", false, any, index_t); + any = new_type ("any", 0, ov_t); + scalar = new_type ("scalar", any, dbl); + range = new_type ("range", any, range_t); + string = new_type ("string", any, string_t); + boolean = new_type ("bool", any, bool_t); + index = new_type ("index", any, index_t); + + casts.resize (next_id + 1); + identities.resize (next_id + 1, 0); // any with anything is an any op llvm::Function *fn; llvm::Type *binary_op_type - = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op)); + = llvm::Type::getIntNTy (context, sizeof (octave_value::binary_op)); llvm::Function *any_binary = create_function ("octave_jit_binary_any_any", any->to_llvm (), binary_op_type, any->to_llvm (), any->to_llvm ()); @@ -234,12 +268,19 @@ reinterpret_cast<void*>(&octave_jit_binary_any_any)); binary_ops.resize (octave_value::num_binary_ops); + for (size_t i = 0; i < octave_value::num_binary_ops; ++i) + { + octave_value::binary_op op = static_cast<octave_value::binary_op> (i); + std::string op_name = octave_value::binary_op_as_string (op); + binary_ops[i].stash_name ("binary" + op_name); + } + for (int op = 0; op < octave_value::num_binary_ops; ++op) { llvm::Twine fn_name ("octave_jit_binary_any_any_"); fn_name = fn_name + llvm::Twine (op); fn = create_function (fn_name, any, any, any); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::APInt op_int(sizeof (octave_value::binary_op), op, std::numeric_limits<octave_value::binary_op>::is_signed); @@ -255,18 +296,28 @@ binary_ops[op].add_overload (overload); } - llvm::Type *void_t = llvm::Type::getVoidTy (ctx); + llvm::Type *void_t = llvm::Type::getVoidTy (context); // grab any - fn = create_function ("octave_jit_grab_any", void_t, any->to_llvm ()); + fn = create_function ("octave_jit_grab_any", any, any); engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_grab_any)); - grab_fn.add_overload (fn, false, 0, any); + grab_fn.add_overload (fn, false, any, any); + grab_fn.stash_name ("grab"); + + // grab scalar + fn = create_identity (scalar); + grab_fn.add_overload (fn, false, scalar, scalar); // release any fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any)); release_fn.add_overload (fn, false, 0, any); + release_fn.stash_name ("release"); + + // release scalar + fn = create_identity (scalar); + release_fn.add_overload (fn, false, 0, scalar); // now for binary scalar operations // FIXME: Finish all operations @@ -287,12 +338,13 @@ add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); // now for printing functions + print_fn.stash_name ("print"); add_print (any, reinterpret_cast<void*> (&octave_jit_print_any)); add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double)); // bounds check for for loop fn = create_function ("octave_jit_simple_for_range", boolean, range, index); - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *nelem @@ -307,7 +359,7 @@ // increment for for loop fn = create_function ("octave_jit_imple_for_range_incr", index, index); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *one = llvm::ConstantInt::get (index_t, 1); @@ -320,7 +372,7 @@ // index variabe for for loop fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *idx = ++fn->arg_begin (); @@ -339,7 +391,7 @@ // logically true // FIXME: Check for NaN fn = create_function ("octave_logically_true_scalar", boolean, scalar); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0); @@ -350,11 +402,33 @@ logically_true.add_overload (fn, true, boolean, scalar); fn = create_function ("octave_logically_true_bool", boolean, boolean); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); builder.CreateRet (fn->arg_begin ()); llvm::verifyFunction (*fn); logically_true.add_overload (fn, false, boolean, boolean); + logically_true.stash_name ("logically_true"); + + casts[any->type_id ()].stash_name ("(any)"); + casts[scalar->type_id ()].stash_name ("(scalar)"); + + // cast any <- scalar + fn = create_function ("octave_jit_cast_any_scalar", any, scalar); + engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_scalar)); + casts[any->type_id ()].add_overload (fn, false, any, scalar); + + // cast scalar <- any + fn = create_function ("octave_jit_cast_scalar_any", scalar, any); + engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any)); + casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + + // cast any <- any + fn = create_identity (any); + casts[any->type_id ()].add_overload (fn, false, any, any); + + // cast scalar <- scalar + fn = create_identity (scalar); + casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); } void @@ -363,14 +437,13 @@ std::stringstream name; name << "octave_jit_print_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::Type *void_t = llvm::Type::getVoidTy (ctx); + llvm::Type *void_t = llvm::Type::getVoidTy (context); llvm::Function *fn = create_function (name.str (), void_t, - llvm::Type::getInt8PtrTy (ctx), + llvm::Type::getInt8PtrTy (context), ty->to_llvm ()); engine->addGlobalMapping (fn, call); - jit_function::overload ol (fn, false, 0, ty); + jit_function::overload ol (fn, false, 0, string, ty); print_fn.add_overload (ol); } @@ -383,9 +456,8 @@ fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); llvm::Function *fn = create_function (fname.str (), ty, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::Instruction::BinaryOps temp = static_cast<llvm::Instruction::BinaryOps>(llvm_op); @@ -406,9 +478,8 @@ fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::CmpInst::Predicate temp = static_cast<llvm::CmpInst::Predicate>(llvm_op); @@ -429,9 +500,8 @@ fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::CmpInst::Predicate temp = static_cast<llvm::CmpInst::Predicate>(llvm_op); @@ -454,10 +524,30 @@ name, module); fn->addFnAttr (llvm::Attribute::AlwaysInline); return fn; -} +} + +llvm::Function * +jit_typeinfo::create_identity (jit_type *type) +{ + size_t id = type->type_id (); + if (id >= identities.size ()) + identities.resize (id + 1, 0); -jit_type* -jit_typeinfo::type_of (const octave_value &ov) const + if (! identities[id]) + { + llvm::Function *fn = create_function ("id", type, type); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + identities[id] = fn; + } + + return identities[id]; +} + +jit_type * +jit_typeinfo::do_type_of (const octave_value &ov) const { if (ov.is_undefined () || ov.is_function ()) return 0; @@ -471,34 +561,21 @@ return get_any (); } -const jit_function& -jit_typeinfo::binary_op (int op) const -{ - assert (static_cast<size_t>(op) < binary_ops.size ()); - return binary_ops[op]; -} - -const jit_function::overload& -jit_typeinfo::print_value (jit_type *to_print) const -{ - return print_fn.get_overload (to_print); -} - void -jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv) +jit_typeinfo::do_to_generic (jit_type *type, llvm::GenericValue& gv) { if (type == any) - to_generic (type, gv, octave_value ()); + do_to_generic (type, gv, octave_value ()); else if (type == scalar) - to_generic (type, gv, octave_value (0)); + do_to_generic (type, gv, octave_value (0)); else if (type == range) - to_generic (type, gv, octave_value (Range ())); + do_to_generic (type, gv, octave_value (Range ())); else assert (false && "Type not supported yet"); } void -jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) +jit_typeinfo::do_to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) { if (type == any) { @@ -522,7 +599,7 @@ } octave_value -jit_typeinfo::to_octave_value (jit_type *type, llvm::GenericValue& gv) +jit_typeinfo::do_to_octave_value (jit_type *type, llvm::GenericValue& gv) { if (type == any) { @@ -545,7 +622,7 @@ } void -jit_typeinfo::reset_generic (void) +jit_typeinfo::do_reset_generic (void) { scalar_out.clear (); ov_out.clear (); @@ -553,926 +630,373 @@ } jit_type* -jit_typeinfo::new_type (const std::string& name, bool force_init, - jit_type *parent, llvm::Type *llvm_type) +jit_typeinfo::new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type) { - jit_type *ret = new jit_type (name, force_init, parent, llvm_type, next_id++); + jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); id_to_type.push_back (ret); return ret; } -// -------------------- jit_infer -------------------- -void -jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds) -{ - infer_simple_for (cmd, bounds); -} - -void -jit_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&) -{ - fail (); -} - -void -jit_infer::visit_argument_list (tree_argument_list&) -{ - fail (); -} - -void -jit_infer::visit_binary_expression (tree_binary_expression& be) +// -------------------- jit_block -------------------- +llvm::BasicBlock * +jit_block::to_llvm (void) const { - if (is_lvalue) - fail (); - - if (be.op_type () >= octave_value::num_binary_ops) - fail (); - - tree_expression *lhs = be.lhs (); - lhs->accept (*this); - jit_type *tlhs = type_stack.back (); - type_stack.pop_back (); - - tree_expression *rhs = be.rhs (); - rhs->accept (*this); - jit_type *trhs = type_stack.back (); - - jit_type *result = tinfo->binary_op_result (be.op_type (), tlhs, trhs); - if (! result) - fail (); - - type_stack.push_back (result); -} - -void -jit_infer::visit_break_command (tree_break_command&) -{ - fail (); + return llvm::cast<llvm::BasicBlock> (llvm_value); } -void -jit_infer::visit_colon_expression (tree_colon_expression&) -{ - fail (); -} - -void -jit_infer::visit_continue_command (tree_continue_command&) -{ - fail (); -} - -void -jit_infer::visit_global_command (tree_global_command&) -{ - fail (); -} - -void -jit_infer::visit_persistent_command (tree_persistent_command&) -{ - fail (); -} - -void -jit_infer::visit_decl_elt (tree_decl_elt&) +// -------------------- jit_call -------------------- +bool +jit_call::infer (void) { - fail (); -} - -void -jit_infer::visit_decl_init_list (tree_decl_init_list&) -{ - fail (); -} - -void -jit_infer::visit_simple_for_command (tree_simple_for_command& cmd) -{ - tree_expression *control = cmd.control_expr (); - control->accept (*this); - - jit_type *control_t = type_stack.back (); - type_stack.pop_back (); - - // FIXME: We should improve type inference so we don't have to do this - // to generate nested for loop code - - // quick hack, check if the for loop bounds are const. If we - // run at least one, we don't have to merge types - bool atleast_once = false; - if (control->is_constant ()) + // FIXME explain algorithm + jit_type *current = type (); + for (size_t i = 0; i < argument_count (); ++i) { - octave_value over = control->rvalue1 (); - if (over.is_range ()) + jit_type *arg_type = argument_type (i); + jit_type *todo = jit_typeinfo::difference (arg_type, already_infered[i]); + if (todo) { - Range rng = over.range_value (); - atleast_once = rng.nelem () > 0; + already_infered[i] = todo; + jit_type *fresult = mfunction.get_result (already_infered); + current = jit_typeinfo::tunion (current, fresult); + already_infered[i] = arg_type; } } - if (atleast_once) - infer_simple_for (cmd, control_t); - else + if (current != type ()) { - type_map fallthrough = types; - infer_simple_for (cmd, control_t); - merge (types, fallthrough); + stash_type (current); + return true; + } + + return false; +} + +// -------------------- jit_convert -------------------- +jit_convert::jit_convert (llvm::Module *module, tree &tee) +{ + jit_instruction::reset_ids (); + + entry_block = new jit_block ("entry"); + blocks.push_back (entry_block); + block = new jit_block ("body"); + blocks.push_back (block); + + final_block = new jit_block ("final"); + visit (tee); + blocks.push_back (final_block); + + entry_block->append (new jit_break (block)); + block->append (new jit_break (final_block)); + + for (variable_map::iterator iter = variables.begin (); + iter != variables.end (); ++iter) + final_block->append (new jit_store_argument (iter->first, iter->second)); + + // FIXME: Maybe we should remove dead code here? + + // initialize the worklist to instructions derived from constants + for (std::list<jit_value *>::iterator iter = constants.begin (); + iter != constants.end (); ++iter) + append_users (*iter); + + // FIXME: Describe algorithm here + while (worklist.size ()) + { + jit_instruction *next = worklist.front (); + worklist.pop_front (); + + if (next->infer ()) + append_users (next); + } + + if (debug_print) + { + std::cout << "-------------------- Compiling tree --------------------\n"; + std::cout << tee.str_print_code () << std::endl; + std::cout << "-------------------- octave jit ir --------------------\n"; + for (std::list<jit_block *>::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + (*iter)->print (std::cout, 0); + std::cout << std::endl; + } + + convert_llvm to_llvm; + function = to_llvm.convert (module, arguments, blocks, constants); + + if (debug_print) + { + std::cout << "-------------------- llvm ir --------------------"; + llvm::raw_os_ostream llvm_cout (std::cout); + function->print (llvm_cout); + std::cout << std::endl; } } void -jit_infer::visit_complex_for_command (tree_complex_for_command&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_script (octave_user_script&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_function (octave_user_function&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_function_header (octave_user_function&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_function_trailer (octave_user_function&) -{ - fail (); -} - -void -jit_infer::visit_function_def (tree_function_def&) -{ - fail (); -} - -void -jit_infer::visit_identifier (tree_identifier& ti) -{ - symbol_table::symbol_record_ref record = ti.symbol (); - handle_identifier (record); -} - -void -jit_infer::visit_if_clause (tree_if_clause&) +jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) { fail (); } void -jit_infer::visit_if_command (tree_if_command& cmd) -{ - if (is_lvalue) - fail (); - - tree_if_command_list *lst = cmd.cmd_list (); - assert (lst); - lst->accept (*this); -} - -void -jit_infer::visit_if_command_list (tree_if_command_list& lst) -{ - // determine the types on each branch of the if seperatly, then merge - type_map fallthrough = types, last; - bool first_time = true; - for (tree_if_command_list::iterator p = lst.begin (); p != lst.end(); ++p) - { - tree_if_clause *tic = *p; - - if (! first_time) - types = fallthrough; - - if (! tic->is_else_clause ()) - { - tree_expression *expr = tic->condition (); - expr->accept (*this); - } - - fallthrough = types; - - tree_statement_list *stmt_lst = tic->commands (); - assert (stmt_lst); - stmt_lst->accept (*this); - - if (first_time) - last = types; - else - merge (last, types); - } - - types = last; - - tree_if_clause *last_clause = lst.back (); - if (! last_clause->is_else_clause ()) - merge (types, fallthrough); -} - -void -jit_infer::visit_index_expression (tree_index_expression&) +jit_convert::visit_argument_list (tree_argument_list&) { fail (); } void -jit_infer::visit_matrix (tree_matrix&) +jit_convert::visit_binary_expression (tree_binary_expression& be) { - fail (); -} - -void -jit_infer::visit_cell (tree_cell&) -{ - fail (); -} + if (be.op_type () >= octave_value::num_binary_ops) + // this is the case for bool_or and bool_and + fail (); -void -jit_infer::visit_multi_assignment (tree_multi_assignment&) -{ - fail (); -} + tree_expression *lhs = be.lhs (); + jit_value *lhsv = visit (lhs); -void -jit_infer::visit_no_op_command (tree_no_op_command&) -{ - fail (); + tree_expression *rhs = be.rhs (); + jit_value *rhsv = visit (rhs); + + const jit_function& fn = jit_typeinfo::binary_op (be.op_type ()); + result = block->append (new jit_call (fn, lhsv, rhsv)); } void -jit_infer::visit_constant (tree_constant& tc) -{ - if (is_lvalue) - fail (); - - octave_value v = tc.rvalue1 (); - jit_type *type = tinfo->type_of (v); - if (! type) - fail (); - - type_stack.push_back (type); -} - -void -jit_infer::visit_fcn_handle (tree_fcn_handle&) -{ - fail (); -} - -void -jit_infer::visit_parameter_list (tree_parameter_list&) -{ - fail (); -} - -void -jit_infer::visit_postfix_expression (tree_postfix_expression&) -{ - fail (); -} - -void -jit_infer::visit_prefix_expression (tree_prefix_expression&) -{ - fail (); -} - -void -jit_infer::visit_return_command (tree_return_command&) +jit_convert::visit_break_command (tree_break_command&) { fail (); } void -jit_infer::visit_return_list (tree_return_list&) +jit_convert::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +jit_convert::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +jit_convert::visit_global_command (tree_global_command&) { fail (); } void -jit_infer::visit_simple_assignment (tree_simple_assignment& tsa) -{ - if (is_lvalue) - fail (); - - // resolve rhs - is_lvalue = false; - tree_expression *rhs = tsa.right_hand_side (); - rhs->accept (*this); - - jit_type *trhs = type_stack.back (); - type_stack.pop_back (); - - // resolve lhs - is_lvalue = true; - rvalue_type = trhs; - tree_expression *lhs = tsa.left_hand_side (); - lhs->accept (*this); - - // we don't pop back here, as the resulting type should be the rhs type - // which is equal to the lhs type anways - jit_type *tlhs = type_stack.back (); - if (tlhs != trhs) - fail (); - - is_lvalue = false; - rvalue_type = 0; -} - -void -jit_infer::visit_statement (tree_statement& stmt) -{ - if (is_lvalue) - fail (); - - tree_command *cmd = stmt.command (); - tree_expression *expr = stmt.expression (); - - if (cmd) - cmd->accept (*this); - else - { - // ok, this check for ans appears three times as cp - bool do_bind_ans = false; - - if (expr->is_identifier ()) - { - tree_identifier *id = dynamic_cast<tree_identifier *> (expr); - - do_bind_ans = (! id->is_variable ()); - } - else - do_bind_ans = (! expr->is_assignment_expression ()); - - expr->accept (*this); - - if (do_bind_ans) - { - is_lvalue = true; - rvalue_type = type_stack.back (); - type_stack.pop_back (); - - symbol_table::symbol_record_ref record (symbol_table::insert ("ans")); - handle_identifier (record); - - if (rvalue_type != type_stack.back ()) - fail (); - - is_lvalue = false; - rvalue_type = 0; - } - - type_stack.pop_back (); - } -} - -void -jit_infer::visit_statement_list (tree_statement_list& lst) -{ - tree_statement_list::iterator iter; - for (iter = lst.begin (); iter != lst.end (); ++iter) - { - tree_statement *stmt = *iter; - assert (stmt); // FIXME: jwe can this be null? - stmt->accept (*this); - } -} - -void -jit_infer::visit_switch_case (tree_switch_case&) +jit_convert::visit_persistent_command (tree_persistent_command&) { fail (); } void -jit_infer::visit_switch_case_list (tree_switch_case_list&) -{ - fail (); -} - -void -jit_infer::visit_switch_command (tree_switch_command&) +jit_convert::visit_decl_elt (tree_decl_elt&) { fail (); } void -jit_infer::visit_try_catch_command (tree_try_catch_command&) +jit_convert::visit_decl_init_list (tree_decl_init_list&) { fail (); } void -jit_infer::visit_unwind_protect_command (tree_unwind_protect_command&) -{ - fail (); -} - -void -jit_infer::visit_while_command (tree_while_command&) -{ - fail (); -} - -void -jit_infer::visit_do_until_command (tree_do_until_command&) +jit_convert::visit_simple_for_command (tree_simple_for_command&) { fail (); } void -jit_infer::infer_simple_for (tree_simple_for_command& cmd, - jit_type *bounds) +jit_convert::visit_complex_for_command (tree_complex_for_command&) { - if (is_lvalue) - fail (); - - jit_type *iter = tinfo->get_simple_for_index_result (bounds); - if (! iter) - fail (); - - is_lvalue = true; - rvalue_type = iter; - tree_expression *lhs = cmd.left_hand_side (); - lhs->accept (*this); - if (type_stack.back () != iter) - fail (); - type_stack.pop_back (); - is_lvalue = false; - rvalue_type = 0; - - tree_statement_list *body = cmd.body (); - body->accept (*this); + fail (); } void -jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record) +jit_convert::visit_octave_user_script (octave_user_script&) { - type_map::iterator iter = types.find (record); - if (iter == types.end ()) - { - jit_type *ty = tinfo->type_of (record->find ()); - bool argin = false; - if (is_lvalue) - { - if (! ty) - ty = rvalue_type; - } - else - { - if (! ty) - fail (); - argin = true; - } - - types[record] = type_entry (argin, ty); - type_stack.push_back (ty); - } - else - type_stack.push_back (iter->second.second); + fail (); } void -jit_infer::merge (type_map& dest, const type_map& src) -{ - if (dest.size () != src.size ()) - fail (); - - type_map::iterator dest_iter; - type_map::const_iterator src_iter; - for (dest_iter = dest.begin (), src_iter = src.begin (); - dest_iter != dest.end (); ++dest_iter, ++src_iter) - { - if (dest_iter->first.name () != src_iter->first.name () - || dest_iter->second.second != src_iter->second.second) - fail (); - - // require argin if one path requires argin - dest_iter->second.first = dest_iter->second.first - || src_iter->second.first; - } -} - -// -------------------- jit_generator -------------------- -jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod, - tree_simple_for_command& cmd, jit_type *bounds, - const type_map& infered_types) - : tinfo (ti), module (mod), is_lvalue (false) +jit_convert::visit_octave_user_function (octave_user_function&) { - // create new vectors that include bounds - std::vector<std::string> names (infered_types.size () + 1); - std::vector<bool> argin (infered_types.size () + 1); - std::vector<jit_type *> types (infered_types.size () + 1); - names[0] = "#bounds"; - argin[0] = true; - types[0] = bounds; - size_t i; - type_map::const_iterator iter; - for (i = 1, iter = infered_types.begin (); iter != infered_types.end (); - ++i, ++iter) - { - names[i] = iter->first.name (); - argin[i] = iter->second.first; - types[i] = iter->second.second; - } - - initialize (names, argin, types); - - try - { - value var_bounds = variables["#bounds"]; - var_bounds.second = builder.CreateLoad (var_bounds.second); - emit_simple_for (cmd, var_bounds, true); - } - catch (const jit_fail_exception&) - { - function->eraseFromParent (); - function = 0; - return; - } - - finalize (names); + fail (); } void -jit_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&) +jit_convert::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_function_def (tree_function_def&) { fail (); } void -jit_generator::visit_argument_list (tree_argument_list&) +jit_convert::visit_identifier (tree_identifier& ti) { - fail (); + std::string name = ti.name (); + variable_map::iterator iter = variables.find (name); + jit_value *var; + if (iter == variables.end ()) + { + octave_value var_value = ti.do_lookup (); + jit_type *var_type = jit_typeinfo::type_of (var_value); + var = entry_block->append (new jit_extract_argument (var_type, name)); + constants.push_back (var); + bounds.push_back (std::make_pair (var_type, name)); + variables[name] = var; + arguments.push_back (std::make_pair (name, true)); + } + else + var = iter->second; + + const jit_function& fn = jit_typeinfo::grab (); + result = block->append (new jit_call (fn, var)); } void -jit_generator::visit_binary_expression (tree_binary_expression& be) -{ - tree_expression *lhs = be.lhs (); - lhs->accept (*this); - value lhsv = value_stack.back (); - value_stack.pop_back (); - - tree_expression *rhs = be.rhs (); - rhs->accept (*this); - value rhsv = value_stack.back (); - value_stack.pop_back (); - - const jit_function::overload& ol - = tinfo->binary_op_overload (be.op_type (), lhsv.first, rhsv.first); - - if (! ol.function) - fail (); - - llvm::Value *result = builder.CreateCall2 (ol.function, lhsv.second, - rhsv.second); - push_value (ol.result, result); -} - -void -jit_generator::visit_break_command (tree_break_command&) -{ - fail (); -} - -void -jit_generator::visit_colon_expression (tree_colon_expression&) -{ - fail (); -} - -void -jit_generator::visit_continue_command (tree_continue_command&) -{ - fail (); -} - -void -jit_generator::visit_global_command (tree_global_command&) +jit_convert::visit_if_clause (tree_if_clause&) { fail (); } void -jit_generator::visit_persistent_command (tree_persistent_command&) -{ - fail (); -} - -void -jit_generator::visit_decl_elt (tree_decl_elt&) -{ - fail (); -} - -void -jit_generator::visit_decl_init_list (tree_decl_init_list&) +jit_convert::visit_if_command (tree_if_command&) { fail (); } void -jit_generator::visit_simple_for_command (tree_simple_for_command& cmd) -{ - if (is_lvalue) - fail (); - - tree_expression *control = cmd.control_expr (); - assert (control); // FIXME: jwe, can this be null? - - control->accept (*this); - value over = value_stack.back (); - value_stack.pop_back (); - - emit_simple_for (cmd, over, false); -} - -void -jit_generator::visit_complex_for_command (tree_complex_for_command&) +jit_convert::visit_if_command_list (tree_if_command_list&) { fail (); } void -jit_generator::visit_octave_user_script (octave_user_script&) +jit_convert::visit_index_expression (tree_index_expression&) { fail (); } void -jit_generator::visit_octave_user_function (octave_user_function&) -{ - fail (); -} - -void -jit_generator::visit_octave_user_function_header (octave_user_function&) -{ - fail (); -} - -void -jit_generator::visit_octave_user_function_trailer (octave_user_function&) +jit_convert::visit_matrix (tree_matrix&) { fail (); } void -jit_generator::visit_function_def (tree_function_def&) +jit_convert::visit_cell (tree_cell&) { fail (); } void -jit_generator::visit_identifier (tree_identifier& ti) +jit_convert::visit_multi_assignment (tree_multi_assignment&) { - std::string name = ti.name (); - value variable = variables[name]; - if (is_lvalue) - { - value_stack.push_back (variable); - - const jit_function::overload& ol = tinfo->release (variable.first); - if (ol.function) - { - llvm::Value *load = builder.CreateLoad (variable.second, name); - builder.CreateCall (ol.function, load); - } - } - else - { - llvm::Value *load = builder.CreateLoad (variable.second, name); - push_value (variable.first, load); - - const jit_function::overload& ol = tinfo->grab (variable.first); - if (ol.function) - builder.CreateCall (ol.function, load); - } + fail (); } void -jit_generator::visit_if_clause (tree_if_clause&) +jit_convert::visit_no_op_command (tree_no_op_command&) { fail (); } void -jit_generator::visit_if_command (tree_if_command& cmd) +jit_convert::visit_constant (tree_constant& tc) { - tree_if_command_list *lst = cmd.cmd_list (); - assert (lst); - lst->accept (*this); + octave_value v = tc.rvalue1 (); + if (v.is_real_scalar () && v.is_double_type ()) + { + double dv = v.double_value (); + result = get_scalar (dv); + } + else if (v.is_range ()) + fail (); + else + fail (); } void -jit_generator::visit_if_command_list (tree_if_command_list& lst) -{ - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "if_tail", function); - std::vector<llvm::BasicBlock *> clause_entry (lst.size ()); - tree_if_command_list::iterator p; - size_t i; - for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i) - { - tree_if_clause *tic = *p; - if (tic->is_else_clause ()) - clause_entry[i] = llvm::BasicBlock::Create (ctx, "else_body", function, - tail); - else - clause_entry[i] = llvm::BasicBlock::Create (ctx, "if_cond", function, - tail); - } - - builder.CreateBr (clause_entry[0]); - - for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i) - { - tree_if_clause *tic = *p; - llvm::BasicBlock *body; - if (tic->is_else_clause ()) - body = clause_entry[i]; - else - { - llvm::BasicBlock *cond = clause_entry[i]; - builder.SetInsertPoint (cond); - - tree_expression *expr = tic->condition (); - expr->accept (*this); - - // FIXME: Handle undefined case - value condv = value_stack.back (); - value_stack.pop_back (); - - const jit_function::overload& ol = tinfo->get_logically_true (condv.first); - if (! ol.function) - fail (); - - bool last = i + 1 == clause_entry.size (); - llvm::BasicBlock *next = last ? tail : clause_entry[i + 1]; - body = llvm::BasicBlock::Create (ctx, "if_body", function, tail); - - llvm::Value *is_true = builder.CreateCall (ol.function, condv.second); - builder.CreateCondBr (is_true, body, next); - } - - tree_statement_list *stmt_lst = tic->commands (); - builder.SetInsertPoint (body); - stmt_lst->accept (*this); - builder.CreateBr (tail); - } - - builder.SetInsertPoint (tail); -} - -void -jit_generator::visit_index_expression (tree_index_expression&) +jit_convert::visit_fcn_handle (tree_fcn_handle&) { fail (); } void -jit_generator::visit_matrix (tree_matrix&) +jit_convert::visit_parameter_list (tree_parameter_list&) { fail (); } void -jit_generator::visit_cell (tree_cell&) -{ - fail (); -} - -void -jit_generator::visit_multi_assignment (tree_multi_assignment&) -{ - fail (); -} - -void -jit_generator::visit_no_op_command (tree_no_op_command&) +jit_convert::visit_postfix_expression (tree_postfix_expression&) { fail (); } void -jit_generator::visit_constant (tree_constant& tc) -{ - octave_value v = tc.rvalue1 (); - llvm::LLVMContext& ctx = llvm::getGlobalContext (); - if (v.is_real_scalar () && v.is_double_type ()) - { - double dv = v.double_value (); - llvm::Value *lv = llvm::ConstantFP::get (ctx, llvm::APFloat (dv)); - push_value (tinfo->get_scalar (), lv); - } - else if (v.is_range ()) - { - Range rng = v.range_value (); - llvm::Type *range = tinfo->get_range_llvm (); - llvm::Type *scalar = tinfo->get_scalar_llvm (); - llvm::Type *index = tinfo->get_index_llvm (); - - std::vector<llvm::Constant *> values (4); - values[0] = llvm::ConstantFP::get (scalar, rng.base ()); - values[1] = llvm::ConstantFP::get (scalar, rng.limit ()); - values[2] = llvm::ConstantFP::get (scalar, rng.inc ()); - values[3] = llvm::ConstantInt::get (index, rng.nelem ()); - - llvm::StructType *llvm_range = llvm::cast<llvm::StructType>(range); - llvm::Value *lv = llvm::ConstantStruct::get (llvm_range, values); - push_value (tinfo->get_range (), lv); - } - else - fail (); -} - -void -jit_generator::visit_fcn_handle (tree_fcn_handle&) +jit_convert::visit_prefix_expression (tree_prefix_expression&) { fail (); } void -jit_generator::visit_parameter_list (tree_parameter_list&) +jit_convert::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +jit_convert::visit_return_list (tree_return_list&) { fail (); } void -jit_generator::visit_postfix_expression (tree_postfix_expression&) +jit_convert::visit_simple_assignment (tree_simple_assignment& tsa) { - fail (); -} - -void -jit_generator::visit_prefix_expression (tree_prefix_expression&) -{ - fail (); -} + // resolve rhs + tree_expression *rhs = tsa.right_hand_side (); + jit_value *rhsv = visit (rhs); -void -jit_generator::visit_return_command (tree_return_command&) -{ - fail (); -} + // resolve lhs + tree_expression *lhs = tsa.left_hand_side (); + if (! lhs->is_identifier ()) + fail (); -void -jit_generator::visit_return_list (tree_return_list&) -{ - fail (); + std::string lhs_name = lhs->name (); + do_assign (lhs_name, rhsv, tsa.print_result ()); + result = rhsv; + + if (jit_instruction *instr = dynamic_cast<jit_instruction *>(rhsv)) + instr->stash_tag (lhs_name); } void -jit_generator::visit_simple_assignment (tree_simple_assignment& tsa) -{ - if (is_lvalue) - fail (); - - // resolve rhs - tree_expression *rhs = tsa.right_hand_side (); - rhs->accept (*this); - - value rhsv = value_stack.back (); - value_stack.pop_back (); - - // resolve lhs - is_lvalue = true; - tree_expression *lhs = tsa.left_hand_side (); - lhs->accept (*this); - is_lvalue = false; - - value lhsv = value_stack.back (); - value_stack.pop_back (); - - // do assign, then keep rhs as the result - builder.CreateStore (rhsv.second, lhsv.second); - - if (tsa.print_result ()) - emit_print (lhs->name (), rhsv); - - value_stack.push_back (rhsv); -} - -void -jit_generator::visit_statement (tree_statement& stmt) +jit_convert::visit_statement (tree_statement& stmt) { tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); if (cmd) - cmd->accept (*this); + visit (cmd); else { // stolen from tree_evaluator::visit_statement @@ -1487,208 +1011,243 @@ else do_bind_ans = (! expr->is_assignment_expression ()); - expr->accept (*this); + jit_value *expr_result = visit (expr); if (do_bind_ans) - { - value rhs = value_stack.back (); - value ans = variables["ans"]; - if (ans.first != rhs.first) - fail (); - - builder.CreateStore (rhs.second, ans.second); - - if (expr->print_result ()) - emit_print ("ans", rhs); - } + do_assign ("ans", expr_result, expr->print_result ()); else if (expr->is_identifier () && expr->print_result ()) { // FIXME: ugly hack, we need to come up with a way to pass // nargout to visit_identifier - emit_print (expr->name (), value_stack.back ()); + const jit_function& fn = jit_typeinfo::print_value (); + jit_const_string *name = get_string (expr->name ()); + block->append (new jit_call (fn, name, expr_result)); } - - - value_stack.pop_back (); } } void -jit_generator::visit_statement_list (tree_statement_list& lst) +jit_convert::visit_statement_list (tree_statement_list&) { - tree_statement_list::iterator iter; - for (iter = lst.begin (); iter != lst.end (); ++iter) - { - tree_statement *stmt = *iter; - assert (stmt); // FIXME: jwe can this be null? - stmt->accept (*this); - } + fail (); } void -jit_generator::visit_switch_case (tree_switch_case&) +jit_convert::visit_switch_case (tree_switch_case&) { fail (); } void -jit_generator::visit_switch_case_list (tree_switch_case_list&) +jit_convert::visit_switch_case_list (tree_switch_case_list&) { fail (); } void -jit_generator::visit_switch_command (tree_switch_command&) +jit_convert::visit_switch_command (tree_switch_command&) { fail (); } void -jit_generator::visit_try_catch_command (tree_try_catch_command&) +jit_convert::visit_try_catch_command (tree_try_catch_command&) { fail (); } void -jit_generator::visit_unwind_protect_command (tree_unwind_protect_command&) +jit_convert::visit_unwind_protect_command (tree_unwind_protect_command&) { fail (); } void -jit_generator::visit_while_command (tree_while_command&) +jit_convert::visit_while_command (tree_while_command&) { fail (); } void -jit_generator::visit_do_until_command (tree_do_until_command&) +jit_convert::visit_do_until_command (tree_do_until_command&) { fail (); } void -jit_generator::emit_simple_for (tree_simple_for_command& cmd, value over, - bool atleast_once) +jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print) { - if (is_lvalue) - fail (); + variable_map::iterator iter = variables.find (lhs); + if (iter == variables.end ()) + arguments.push_back (std::make_pair (lhs, false)); + else + { + const jit_function& fn = jit_typeinfo::release (); + block->append (new jit_call (fn, iter->second)); + } - jit_type *index = tinfo->get_index (); - llvm::Value *init_index = 0; - if (over.first == tinfo->get_range ()) - init_index = llvm::ConstantInt::get (index->to_llvm (), 0); - else - fail (); + variables[lhs] = rhs; - llvm::Value *llvm_index = builder.CreateAlloca (index->to_llvm (), 0, "index"); - builder.CreateStore (init_index, llvm_index); + if (print) + { + const jit_function& fn = jit_typeinfo::print_value (); + jit_const_string *name = get_string (lhs); + block->append (new jit_call (fn, name, rhs)); + } +} + +jit_value * +jit_convert::visit (tree& tee) +{ + result = 0; + tee.accept (*this); - // FIXME: Support break - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "for_body", function); - llvm::BasicBlock *cond_check = llvm::BasicBlock::Create (ctx, "for_check", function); - llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "for_tail", function); + jit_value *ret = result; + result = 0; + return ret; +} - // initialize the iter from the index - if (atleast_once) - builder.CreateBr (body); - else - builder.CreateBr (cond_check); - - builder.SetInsertPoint (body); +// -------------------- jit_convert::convert_llvm -------------------- +llvm::Function * +jit_convert::convert_llvm::convert (llvm::Module *module, + const std::vector<std::pair< std::string, bool> >& args, + const std::list<jit_block *>& blocks, + const std::list<jit_value *>& constants) +{ + jit_type *any = jit_typeinfo::get_any (); - is_lvalue = true; - tree_expression *lhs = cmd.left_hand_side (); - lhs->accept (*this); - is_lvalue = false; + // argument is an array of octave_base_value*, or octave_base_value** + llvm::Type *arg_type = any->to_llvm (); // this is octave_base_value* + arg_type = arg_type->getPointerTo (); + llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), + arg_type, false); + llvm::Function *function = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + "foobar", module); - value lhsv = value_stack.back (); - value_stack.pop_back (); + try + { + llvm::BasicBlock *prelude = llvm::BasicBlock::Create (context, "prelude", + function); + builder.SetInsertPoint (prelude); - const jit_function::overload& index_ol = tinfo->get_simple_for_index (over.first); - llvm::Value *lindex = builder.CreateLoad (llvm_index); - llvm::Value *llvm_iter = builder.CreateCall2 (index_ol.function, over.second, lindex); - value iter(index_ol.result, llvm_iter); - builder.CreateStore (iter.second, lhsv.second); + llvm::Value *arg = function->arg_begin (); + for (size_t i = 0; i < args.size (); ++i) + { + llvm::Value *loaded_arg = builder.CreateConstInBoundsGEP1_32 (arg, i); + arguments[args[i].first] = loaded_arg; + } - tree_statement_list *lst = cmd.body (); - lst->accept (*this); + // we need to generate llvm values for constants, as these don't appear in + // a block + for (std::list<jit_value *>::const_iterator iter = constants.begin (); + iter != constants.end (); ++iter) + { + jit_value *constant = *iter; + if (! dynamic_cast<jit_instruction *> (constant)) + visit (constant); + } - llvm::Value *one = llvm::ConstantInt::get (index->to_llvm (), 1); - lindex = builder.CreateLoad (llvm_index); - lindex = builder.CreateAdd (lindex, one); - builder.CreateStore (lindex, llvm_index); - builder.CreateBr (cond_check); + std::list<jit_block *>::const_iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *jblock = *biter; + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, jblock->name (), + function); + jblock->stash_llvm (block); + } + + jit_block *first = *blocks.begin (); + builder.CreateBr (first->to_llvm ()); - builder.SetInsertPoint (cond_check); - lindex = builder.CreateLoad (llvm_index); - const jit_function::overload& check_ol = tinfo->get_simple_for_check (over.first); - llvm::Value *cond = builder.CreateCall2 (check_ol.function, over.second, lindex); - builder.CreateCondBr (cond, body, tail); + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + visit (*biter); - builder.SetInsertPoint (tail); + builder.CreateRetVoid (); + } catch (const jit_fail_exception&) + { + function->eraseFromParent (); + throw; + } + + llvm::verifyFunction (*function); + + return function; } void -jit_generator::emit_print (const std::string& name, const value& v) +jit_convert::convert_llvm::visit_const_string (jit_const_string& cs) +{ + cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); +} + +void +jit_convert::convert_llvm::visit_const_scalar (jit_const_scalar& cs) +{ + llvm::Type *dbl = llvm::Type::getDoubleTy (context); + cs.stash_llvm (llvm::ConstantFP::get (dbl, cs.value ())); +} + +void +jit_convert::convert_llvm::visit_block (jit_block& b) { - const jit_function::overload& ol = tinfo->print_value (v.first); - if (! ol.function) - fail (); + llvm::BasicBlock *block = b.to_llvm (); + builder.SetInsertPoint (block); + for (jit_block::iterator iter = b.begin (); iter != b.end (); ++iter) + visit (*iter); +} - llvm::Value *str = builder.CreateGlobalStringPtr (name); - builder.CreateCall2 (ol.function, str, v.second); +void +jit_convert::convert_llvm::visit_break (jit_break& b) +{ + builder.CreateBr (b.sucessor_llvm ()); +} + +void +jit_convert::convert_llvm::visit_cond_break (jit_cond_break& cb) +{ + llvm::Value *cond = cb.cond_llvm (); + builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); } void -jit_generator::initialize (const std::vector<std::string>& names, - const std::vector<bool>& argin, - const std::vector<jit_type *> types) +jit_convert::convert_llvm::visit_call (jit_call& call) { - std::vector<llvm::Type *> arg_types (names.size ()); - for (size_t i = 0; i < types.size (); ++i) - arg_types[i] = types[i]->to_llvm_arg (); - - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); - llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); - function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, - "foobar", module); + const jit_function::overload& ol = call.overload (); + if (! ol.function) + fail (); + + std::vector<llvm::Value *> args (call.argument_count ()); + for (size_t i = 0; i < call.argument_count (); ++i) + args[i] = call.argument_llvm (i); - // create variables and copy initial values - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); - builder.SetInsertPoint (body); - llvm::Function::arg_iterator arg_iter = function->arg_begin(); - for (size_t i = 0; i < names.size (); ++i, ++arg_iter) - { - llvm::Type *vartype = types[i]->to_llvm (); - const std::string& name = names[i]; - llvm::Value *var = builder.CreateAlloca (vartype, 0, name); - variables[name] = value (types[i], var); - - if (argin[i] || types[i]->force_init ()) - { - llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); - builder.CreateStore (loaded_arg, var); - } - } + call.stash_llvm (builder.CreateCall (ol.function, args)); } void -jit_generator::finalize (const std::vector<std::string>& names) +jit_convert::convert_llvm::visit_extract_argument (jit_extract_argument& extract) { - // copy computed values back into arguments - // we use names instead of looping through variables because order is - // important - llvm::Function::arg_iterator arg_iter = function->arg_begin(); - for (size_t i = 0; i < names.size (); ++i, ++arg_iter) - { - llvm::Value *var = variables[names[i]].second; - llvm::Value *loaded_var = builder.CreateLoad (var); - builder.CreateStore (loaded_var, arg_iter); - } - builder.CreateRetVoid (); + const jit_function::overload& ol = extract.overload (); + if (! ol.function) + fail (); + + llvm::Value *arg = arguments[extract.tag ()]; + arg = builder.CreateLoad (arg); + extract.stash_llvm (builder.CreateCall (ol.function, arg)); +} + +void +jit_convert::convert_llvm::visit_store_argument (jit_store_argument& store) +{ + llvm::Value *arg_value = store.result_llvm (); + const jit_function::overload& ol = store.overload (); + if (! ol.function) + fail (); + + arg_value = builder.CreateCall (ol.function, arg_value); + + llvm::Value *arg = arguments[store.tag ()]; + store.stash_llvm (builder.CreateStore (arg_value, arg)); } // -------------------- tree_jit -------------------- @@ -1700,25 +1259,33 @@ } tree_jit::~tree_jit (void) -{ - delete tinfo; -} +{} bool -tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) +tree_jit::execute (tree& cmd) { if (! initialize ()) return false; - jit_type *bounds_t = tinfo->type_of (bounds); - jit_info *jinfo = cmd.get_info (bounds_t); + compiled_map::iterator iter = compiled.find (&cmd); + jit_info *jinfo = 0; + if (iter != compiled.end ()) + { + jinfo = iter->second; + if (! jinfo->match ()) + { + delete jinfo; + jinfo = 0; + } + } + if (! jinfo) { - jinfo = new jit_info (*this, cmd, bounds_t); - cmd.stash_info (bounds_t, jinfo); + jinfo = new jit_info (*this, cmd); + compiled[&cmd] = jinfo; } - return jinfo->execute (bounds); + return jinfo->execute (); } bool @@ -1746,7 +1313,7 @@ pass_manager->add (llvm::createCFGSimplificationPass ()); pass_manager->doInitialization (); - tinfo = new jit_typeinfo (module, engine); + jit_typeinfo::initialize (module, engine); return true; } @@ -1760,106 +1327,80 @@ } // -------------------- jit_info -------------------- -jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd, - jit_type *bounds) : tinfo (tjit.get_typeinfo ()), - engine (tjit.get_engine ()), - bounds_t (bounds) +jit_info::jit_info (tree_jit& tjit, tree& tee) + : engine (tjit.get_engine ()) { - jit_infer infer(tinfo); - + llvm::Function *fun = 0; try { - infer.infer (cmd, bounds); + jit_convert conv (tjit.get_module (), tee); + fun = conv.get_function (); + arguments = conv.get_arguments (); + bounds = conv.get_bounds (); } catch (const jit_fail_exception&) + {} + + if (! fun) { function = 0; return; } - types = infer.get_types (); - - jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types); - function = gen.get_function (); - - if (function) - { - if (debug_print) - { - std::cout << "Compiled code:\n"; - std::cout << cmd.str_print_code () << std::endl; - - std::cout << "Before optimization:\n"; + tjit.optimize (fun); - llvm::raw_os_ostream os (std::cout); - function->print (os); - } - llvm::verifyFunction (*function); - tjit.optimize (function); + if (debug_print) + { + std::cout << "-------------------- optimized llvm ir --------------------\n"; + llvm::raw_os_ostream llvm_cout (std::cout); + fun->print (llvm_cout); + std::cout << std::endl; + } - if (debug_print) - { - std::cout << "After optimization:\n"; - - llvm::raw_os_ostream os (std::cout); - function->print (os); - } - } + function = reinterpret_cast<jited_function>(engine->getPointerToFunction (fun)); } bool -jit_info::execute (const octave_value& bounds) const +jit_info::execute (void) const { if (! function) return false; - std::vector<llvm::GenericValue> args (types.size () + 1); - tinfo->to_generic (bounds_t, args[0], bounds); - - size_t idx; - type_map::const_iterator iter; - for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) + std::vector<octave_base_value *> real_arguments (arguments.size ()); + for (size_t i = 0; i < arguments.size (); ++i) { - if (iter->second.first) // argin? + if (arguments[i].second) { - octave_value ov = iter->first->varval (); - tinfo->to_generic (iter->second.second, args[idx], ov); + octave_value current = symbol_table::varval (arguments[i].first); + octave_base_value *obv = current.internal_rep (); + obv->grab (); + real_arguments[i] = obv; } - else - tinfo->to_generic (iter->second.second, args[idx]); } - engine->runFunction (function, args); + function (&real_arguments[0]); - for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) - { - octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]); - octave_value &ref = iter->first->varref (); - ref = result; - } - - tinfo->reset_generic (); + for (size_t i = 0; i < arguments.size (); ++i) + symbol_table::varref (arguments[i].first) = real_arguments[i]; return true; } bool -jit_info::match () const +jit_info::match (void) const { - for (type_map::const_iterator iter = types.begin (); iter != types.end (); - ++iter) - + if (! function) + return true; + + for (size_t i = 0; i < bounds.size (); ++i) { - if (iter->second.first) // argin? - { - jit_type *required_type = iter->second.second; - octave_value val = iter->first->varval (); - jit_type *current_type = tinfo->type_of (val); + const std::string& arg_name = bounds[i].second; + octave_value value = symbol_table::varval (arg_name); + jit_type *type = jit_typeinfo::type_of (value); - // FIXME: should be: ! required_type->is_parent (current_type) - if (required_type != current_type) - return false; - } + // FIXME: Check for a parent relationship + if (type != bounds[i].first) + return false; } return true;