# HG changeset patch # User Max Brister # Date 1347172140 21600 # Node ID 3f43e9d6d86ee9472167edfb8251d550d586ce9e # Parent 5fff79162342c74688df37525a137173fbbd68b1 JIT compile anonymous functions * jit-ir.h (jit_block::front, jit_block::back): New function. (jit_call::jit_call): New overloads. (jit_return): New class. * jit-typeinfo.cc (octave_jit_create_undef): New function. (jit_operation::to_idx): Correctly handle empty type vector. (jit_typeinfo::jit_typeinfo): Add destroy_fn and initialize create_undef. * jit-typeinfo.h (jit_typeinfo::get_any_ptr, jit_typeinfo::destroy, jit_typeinfo::create_undef): New function. * pt-jit.cc (jit_convert::jit_convert): Add overload and refactor. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute, jit_function_info::jit_function_info, jit_function_info::execute, jit_function_info::match): New function. (jit_convert::get_variable): Support function variable lookup. (jit_convert_llvm::convert): Handle loop/function agnostic stuff. (jit_convert_llvm::visit): Handle function creation as well. (tree_jit::execute): Move implementation to tree_jit::do_execute. (jit_info::compile): Call convert_loop instead of convert. * pt-jit.h (jit_convert::jit_convert): New overload. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute): New function. (jit_convert::create_variable, jit_convert_llvm::initialize): Update signature. (tree_jit::execute): Made static. (tree_jit::tree_jit): Made private. (jit_function_info): New class. * ov-usr-fcn.cc (octave_user_function::~octave_user_function): Delete jit_info. (octave_user_function::octave_user_function): Maybe JIT and use is_special_expr and special_expr. (octave_user_function::special_expr): New function. * ov-usr-fcn.h (octave_user_function::is_special_expr, octave_user_function::special_expr, octave_user_function::get_info, octave_user_function::stash_info): New function. * pt-decl.h (tree_decl_elt::name): New function. * pt-eval.cc (tree_evaluator::visit_simple_for_command, tree_evaluator::visit_while_command): Use static tree_jit methods. diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/interp-core/jit-ir.h --- a/libinterp/interp-core/jit-ir.h Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/interp-core/jit-ir.h Sun Sep 09 00:29:00 2012 -0600 @@ -42,6 +42,7 @@ JIT_METH(call); \ JIT_METH(extract_argument); \ JIT_METH(store_argument); \ + JIT_METH(return); \ JIT_METH(phi); \ JIT_METH(variable); \ JIT_METH(error_check); \ @@ -768,6 +769,10 @@ return true; } + jit_instruction *front (void) { return instructions.front (); } + + jit_instruction *back (void) { return instructions.back (); } + JIT_VALUE_ACCEPT; private: void internal_append (jit_instruction *instr); @@ -1149,6 +1154,21 @@ jit_call : public jit_instruction { public: + jit_call (const jit_operation& (*aoperation) (void)) + : moperation (aoperation ()) + { + const jit_function& ol = overload (); + if (ol.valid ()) + stash_type (ol.result ()); + } + + jit_call (const jit_operation& aoperation) : moperation (aoperation) + { + const jit_function& ol = overload (); + if (ol.valid ()) + stash_type (ol.result ()); + } + #define JIT_CALL_CONST(N) \ jit_call (const jit_operation& aoperation, \ OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ @@ -1366,6 +1386,38 @@ }; class +jit_return : public jit_instruction +{ +public: + jit_return (void) {} + + jit_return (jit_value *retval) : jit_instruction (retval) {} + + jit_value *result (void) const + { + return argument_count () ? argument (0) : 0; + } + + jit_type *result_type (void) const + { + jit_value *res = result (); + return res ? res->type () : 0; + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "return"; + + if (result ()) + os << " " << *result (); + + return os; + } + + JIT_VALUE_ACCEPT; +}; + +class jit_ir_walker { public: diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/interp-core/jit-typeinfo.cc --- a/libinterp/interp-core/jit-typeinfo.cc Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/interp-core/jit-typeinfo.cc Sun Sep 09 00:29:00 2012 -0600 @@ -366,6 +366,16 @@ return idx < ndim ? mat->dimensions[idx] : 1; } +extern "C" octave_base_value * +octave_jit_create_undef (void) +{ + octave_value undef; + octave_base_value *ret = undef.internal_rep (); + ret->grab (); + + return ret; +} + extern "C" Complex octave_jit_complex_div (Complex lhs, Complex rhs) { @@ -791,14 +801,15 @@ jit_operation::to_idx (const std::vector& types) const { octave_idx_type numel = types.size (); - if (numel == 1) - numel = 2; + numel = std::max (2, numel); Array idx (dim_vector (1, numel)); for (octave_idx_type i = 0; i < static_cast (types.size ()); ++i) idx(i) = types[i]->type_id (); + if (types.size () == 0) + idx(0) = idx(1) = 0; if (types.size () == 1) { idx(1) = idx(0); @@ -1149,6 +1160,14 @@ fn.add_mapping (engine, &octave_jit_release_matrix); release_fn.add_overload (fn); + // destroy + destroy_fn = release_fn; + destroy_fn.stash_name ("destroy"); + destroy_fn.add_overload (create_identity(scalar)); + destroy_fn.add_overload (create_identity(boolean)); + destroy_fn.add_overload (create_identity(index)); + destroy_fn.add_overload (create_identity(complex)); + // now for binary scalar operations add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); @@ -1702,6 +1721,12 @@ scalar, matrix, index, index); end_fn.add_overload (fn); + // -------------------- create_undef -------------------- + create_undef_fn.stash_name ("create_undef"); + fn = create_function (jit_convention::external, "octave_jit_create_undef", + any); + create_undef_fn.add_overload (fn); + casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); casts[complex->type_id ()].stash_name ("(complex)"); diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/interp-core/jit-typeinfo.h --- a/libinterp/interp-core/jit-typeinfo.h Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/interp-core/jit-typeinfo.h Sun Sep 09 00:29:00 2012 -0600 @@ -452,6 +452,8 @@ static jit_type *get_scalar_ptr (void) { return instance->scalar_ptr; } + static jit_type *get_any_ptr (void) { return instance->any_ptr; } + static jit_type *get_range (void) { return instance->range; } static jit_type *get_string (void) { return instance->string; } @@ -498,6 +500,11 @@ return instance->release_fn.overload (type); } + static const jit_operation& destroy (void) + { + return instance->destroy_fn; + } + static const jit_operation& print_value (void) { return instance->print_fn; @@ -563,6 +570,11 @@ { return instance->do_end (value, index, count); } + + static const jit_operation& create_undef (void) + { + return instance->create_undef_fn; + } private: jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); @@ -751,6 +763,7 @@ std::vector unary_ops; jit_operation grab_fn; jit_operation release_fn; + jit_operation destroy_fn; jit_operation print_fn; jit_operation for_init_fn; jit_operation for_check_fn; @@ -761,6 +774,7 @@ jit_paren_subsasgn paren_subsasgn_fn; jit_operation end1_fn; jit_operation end_fn; + jit_operation create_undef_fn; jit_function any_call; diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/interp-core/pt-jit.cc --- a/libinterp/interp-core/pt-jit.cc Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/interp-core/pt-jit.cc Sun Sep 09 00:29:00 2012 -0600 @@ -65,22 +65,16 @@ // -------------------- jit_convert -------------------- jit_convert::jit_convert (tree &tee, jit_type *for_bounds) - : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false) + : converting_function (false) { - jit_instruction::reset_ids (); - - entry_block = factory.create ("body"); - final_block = factory.create ("final"); - blocks.push_back (entry_block); - entry_block->mark_alive (); - block = entry_block; + initialize (symbol_table::current_scope ()); if (for_bounds) create_variable (next_for_bounds (false), for_bounds); visit (tee); - // FIXME: Remove if we no longer only compile loops + // breaks must have been handled by the top level loop assert (! breaking); assert (breaks.empty ()); assert (continues.empty ()); @@ -95,6 +89,91 @@ if (name.size () && name[0] != '#') final_block->append (factory.create (var)); } + + final_block->append (factory.create ()); +} + +jit_convert::jit_convert (octave_user_function& fcn, + const std::vector& args) + : converting_function (true) +{ + initialize (fcn.scope ()); + + tree_parameter_list *plist = fcn.parameter_list (); + tree_parameter_list *rlist = fcn.return_list (); + if (plist && plist->takes_varargs ()) + throw jit_fail_exception ("varags not supported"); + + if (rlist && (rlist->size () > 1 || rlist->takes_varargs ())) + throw jit_fail_exception ("multiple returns not supported"); + + if (plist) + { + tree_parameter_list::iterator piter = plist->begin (); + for (size_t i = 0; i < args.size (); ++i, ++piter) + { + if (piter == plist->end ()) + throw jit_fail_exception ("Too many parameter to function"); + + tree_decl_elt *elt = *piter; + std::string name = elt->name (); + create_variable (name, args[i]); + } + } + + jit_value *return_value = 0; + if (fcn.is_special_expr ()) + { + tree_expression *expr = fcn.special_expr (); + if (expr) + { + jit_variable *retvar = get_variable ("#return"); + jit_value *retval = visit (expr); + block->append (factory.create (retvar, retval)); + return_value = retvar; + } + } + else + visit_statement_list (*fcn.body ()); + + // the user may use break or continue to exit the function. Because the + // function does not start as a loop, we can have one continue, one break, or + // a regular fallthrough to exit the function + if (continues.size ()) + { + assert (! continues.size ()); + finish_breaks (final_block, continues); + } + else if (breaks.size ()) + finish_breaks (final_block, breaks); + else + block->append (factory.create (final_block)); + blocks.push_back (final_block); + block = final_block; + + if (! return_value && rlist && rlist->size () == 1) + { + tree_decl_elt *elt = rlist->front (); + return_value = get_variable (elt->name ()); + } + + // FIXME: We should use live range analysis to delete variables where needed. + // For now we just delete everything at the end of the function. + for (variable_map::iterator iter = vmap.begin (); iter != vmap.end (); ++iter) + { + if (iter->second != return_value) + { + jit_call *call; + call = factory.create (&jit_typeinfo::destroy, + iter->second); + final_block->append (call); + } + } + + if (return_value) + final_block->append (factory.create (return_value)); + else + final_block->append (factory.create ()); } void @@ -719,6 +798,23 @@ throw jit_fail_exception (); } +void +jit_convert::initialize (symbol_table::scope_id s) +{ + scope = s; + iterator_count = 0; + for_bounds_count = 0; + short_count = 0; + breaking = false; + jit_instruction::reset_ids (); + + entry_block = factory.create ("body"); + final_block = factory.create ("final"); + blocks.push_back (entry_block); + entry_block->mark_alive (); + block = entry_block; +} + jit_call * jit_convert::create_checked_impl (jit_call *ret) { @@ -749,20 +845,42 @@ if (ret) return ret; - octave_value val = symbol_table::find (vname); - jit_type *type = jit_typeinfo::type_of (val); - bounds.push_back (type_bound (type, vname)); + symbol_table::symbol_record record = symbol_table::find_symbol (vname, scope); + if (record.is_persistent () || record.is_global ()) + throw jit_fail_exception ("Persistent and global not yet supported"); - return create_variable (vname, type); + if (converting_function) + return create_variable (vname, jit_typeinfo::get_any (), false); + else + { + octave_value val = record.varval (); + jit_type *type = jit_typeinfo::type_of (val); + bounds.push_back (type_bound (type, vname)); + + return create_variable (vname, type); + } } jit_variable * -jit_convert::create_variable (const std::string& vname, jit_type *type) +jit_convert::create_variable (const std::string& vname, jit_type *type, + bool isarg) { jit_variable *var = factory.create (vname); - jit_extract_argument *extract; - extract = factory.create (type, var); - entry_block->prepend (extract); + + if (isarg) + { + jit_extract_argument *extract; + extract = factory.create (type, var); + entry_block->prepend (extract); + } + else + { + jit_call *init = factory.create (&jit_typeinfo::create_undef); + jit_assign *assign = factory.create (var, init); + entry_block->prepend (assign); + entry_block->prepend (init); + } + return vmap[vname] = var; } @@ -898,10 +1016,12 @@ // -------------------- jit_convert_llvm -------------------- llvm::Function * -jit_convert_llvm::convert (llvm::Module *module, - const jit_block_list& blocks, - const std::list& constants) +jit_convert_llvm::convert_loop (llvm::Module *module, + const jit_block_list& blocks, + const std::list& constants) { + converting_function = false; + // for now just init arguments from entry, later we will have to do something // more interesting jit_block *entry_block = blocks.front (); @@ -934,44 +1054,7 @@ arguments[argument_vec[i].first] = loaded_arg; } - std::list::const_iterator biter; - for (biter = blocks.begin (); biter != blocks.end (); ++biter) - { - jit_block *jblock = *biter; - llvm::BasicBlock *block = llvm::BasicBlock::Create (context, - jblock->name (), - function); - jblock->stash_llvm (block); - } - - jit_block *first = *blocks.begin (); - builder.CreateBr (first->to_llvm ()); - - // constants aren't in the IR, we visit those first - for (std::list::const_iterator iter = constants.begin (); - iter != constants.end (); ++iter) - if (! isa (*iter)) - visit (*iter); - - // convert all instructions - for (biter = blocks.begin (); biter != blocks.end (); ++biter) - visit (*biter); - - // now finish phi nodes - for (biter = blocks.begin (); biter != blocks.end (); ++biter) - { - jit_block& block = **biter; - for (jit_block::iterator piter = block.begin (); - piter != block.end () && isa (*piter); ++piter) - { - jit_instruction *phi = *piter; - finish_phi (static_cast (phi)); - } - } - - jit_block *last = blocks.back (); - builder.SetInsertPoint (last->to_llvm ()); - builder.CreateRetVoid (); + convert (blocks, constants); } catch (const jit_fail_exception& e) { function->eraseFromParent (); @@ -981,6 +1064,92 @@ return function; } + +jit_function +jit_convert_llvm::convert_function (llvm::Module *module, + const jit_block_list& blocks, + const std::list& constants, + octave_user_function& fcn, + const std::vector& args) +{ + converting_function = true; + + jit_block *final_block = blocks.back (); + jit_return *ret = dynamic_cast (final_block->back ()); + assert (ret); + + jit_function creating = jit_function (module, jit_convention::internal, + "foobar", ret->result_type (), args); + function = creating.to_llvm (); + + try + { + prelude = creating.new_block ("prelude"); + builder.SetInsertPoint (prelude); + + tree_parameter_list *plist = fcn.parameter_list (); + if (plist) + { + tree_parameter_list::iterator piter = plist->begin (); + tree_parameter_list::iterator pend = plist->end (); + for (size_t i = 0; i < args.size () && piter != pend; ++i, ++piter) + { + tree_decl_elt *elt = *piter; + std::string arg_name = elt->name (); + arguments[arg_name] = creating.argument (builder, i); + } + } + + convert (blocks, constants); + } catch (const jit_fail_exception& e) + { + function->eraseFromParent (); + throw; + } + + return creating; +} + +void +jit_convert_llvm::convert (const jit_block_list& blocks, + const std::list& constants) +{ + std::list::const_iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *jblock = *biter; + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, + jblock->name (), + function); + jblock->stash_llvm (block); + } + + jit_block *first = *blocks.begin (); + builder.CreateBr (first->to_llvm ()); + + // constants aren't in the IR, we visit those first + for (std::list::const_iterator iter = constants.begin (); + iter != constants.end (); ++iter) + if (! isa (*iter)) + visit (*iter); + + // convert all instructions + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + visit (*biter); + + // now finish phi nodes + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block& block = **biter; + for (jit_block::iterator piter = block.begin (); + piter != block.end () && isa (*piter); ++piter) + { + jit_instruction *phi = *piter; + finish_phi (static_cast (phi)); + } + } +} + void jit_convert_llvm::finish_phi (jit_phi *phi) { @@ -1089,10 +1258,16 @@ { llvm::Value *arg = arguments[extract.name ()]; assert (arg); - arg = builder.CreateLoad (arg); - const jit_function& ol = extract.overload (); - extract.stash_llvm (ol.call (builder, arg)); + if (converting_function) + extract.stash_llvm (arg); + else + { + arg = builder.CreateLoad (arg); + + const jit_function& ol = extract.overload (); + extract.stash_llvm (ol.call (builder, arg)); + } } void @@ -1105,6 +1280,16 @@ } void +jit_convert_llvm::visit (jit_return& ret) +{ + jit_value *res = ret.result (); + if (res) + builder.CreateRet (res->to_llvm ()); + else + builder.CreateRetVoid (); +} + +void jit_convert_llvm::visit (jit_phi& phi) { // we might not have converted all incoming branches, so we don't @@ -1539,44 +1724,27 @@ bool tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) { - const size_t MIN_TRIP_COUNT = 1000; - - size_t tc = trip_count (bounds); - if (! tc || ! initialize ()) - return false; - - jit_info::vmap extra_vars; - extra_vars["#for_bounds0"] = &bounds; - - jit_info *info = cmd.get_info (); - if (! info || ! info->match (extra_vars)) - { - if (tc < MIN_TRIP_COUNT) - return false; - - delete info; - info = new jit_info (*this, cmd, bounds); - cmd.stash_info (info); - } - - return info->execute (extra_vars); + return instance ().do_execute (cmd, bounds); } bool tree_jit::execute (tree_while_command& cmd) { - if (! initialize ()) - return false; + return instance ().do_execute (cmd); +} - jit_info *info = cmd.get_info (); - if (! info || ! info->match ()) - { - delete info; - info = new jit_info (*this, cmd); - cmd.stash_info (info); - } +bool +tree_jit::execute (octave_user_function& fcn, const octave_value_list& args, + octave_value_list& retval) +{ + return instance ().do_execute (fcn, args, retval); +} - return info->execute (); +tree_jit& +tree_jit::instance (void) +{ + static tree_jit ret; + return ret; } bool @@ -1616,6 +1784,67 @@ return true; } +bool +tree_jit::do_execute (tree_simple_for_command& cmd, const octave_value& bounds) +{ + const size_t MIN_TRIP_COUNT = 1000; + + size_t tc = trip_count (bounds); + if (! tc || ! initialize ()) + return false; + + jit_info::vmap extra_vars; + extra_vars["#for_bounds0"] = &bounds; + + jit_info *info = cmd.get_info (); + if (! info || ! info->match (extra_vars)) + { + if (tc < MIN_TRIP_COUNT) + return false; + + delete info; + info = new jit_info (*this, cmd, bounds); + cmd.stash_info (info); + } + + return info->execute (extra_vars); +} + +bool +tree_jit::do_execute (tree_while_command& cmd) +{ + if (! initialize ()) + return false; + + jit_info *info = cmd.get_info (); + if (! info || ! info->match ()) + { + delete info; + info = new jit_info (*this, cmd); + cmd.stash_info (info); + } + + return info->execute (); +} + +bool +tree_jit::do_execute (octave_user_function& fcn, const octave_value_list& args, + octave_value_list& retval) +{ + if (! initialize ()) + return false; + + jit_function_info *info = fcn.get_info (); + if (! info || ! info->match (args)) + { + delete info; + info = new jit_function_info (*this, fcn, args); + fcn.stash_info (info); + } + + return info->execute (args, retval); +} + size_t tree_jit::trip_count (const octave_value& bounds) const { @@ -1644,6 +1873,163 @@ #endif } +// -------------------- jit_function_info -------------------- +jit_function_info::jit_function_info (tree_jit& tjit, + octave_user_function& fcn, + const octave_value_list& ov_args) + : argument_types (ov_args.length ()), function (0) +{ + size_t nargs = ov_args.length (); + for (size_t i = 0; i < nargs; ++i) + argument_types[i] = jit_typeinfo::type_of (ov_args(i)); + + try + { + jit_convert conv (fcn, argument_types); + jit_infer infer (conv.get_factory (), conv.get_blocks (), + conv.get_variable_map ()); + infer.infer (); + +#if OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + jit_block_list& blocks = infer.get_blocks (); + jit_block *entry_block = blocks.front (); + entry_block->label (); + std::cout << "-------------------- Compiling function "; + std::cout << "--------------------\n"; + + tree_print_code tpc (std::cout); + tpc.visit_octave_user_function_header (fcn); + tpc.visit_statement_list (*fcn.body ()); + tpc.visit_octave_user_function_trailer (fcn); + blocks.print (std::cout, "octave jit ir"); + } +#endif + + jit_factory& factory = conv.get_factory (); + llvm::Module *module = tjit.get_module (); + jit_convert_llvm to_llvm; + jit_function raw_fn = to_llvm.convert_function (module, + infer.get_blocks (), + factory.constants (), + fcn, argument_types); + +#ifdef OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + std::cout << "-------------------- raw function "; + std::cout << "--------------------\n"; + std::cout << *raw_fn.to_llvm () << std::endl; + } +#endif + + std::string wrapper_name = fcn.name () + "_wrapper"; + jit_type *any_t = jit_typeinfo::get_any (); + std::vector wrapper_args (1, jit_typeinfo::get_any_ptr ()); + jit_function wrapper (module, jit_convention::internal, wrapper_name, + any_t, wrapper_args); + llvm::BasicBlock *wrapper_body = wrapper.new_block (); + builder.SetInsertPoint (wrapper_body); + + llvm::Value *wrapper_arg = wrapper.argument (builder, 0); + std::vector raw_args (nargs); + for (size_t i = 0; i < nargs; ++i) + { + llvm::Value *arg; + arg = builder.CreateConstInBoundsGEP1_32 (wrapper_arg, i); + arg = builder.CreateLoad (arg); + + jit_type *arg_type = argument_types[i]; + const jit_function& cast = jit_typeinfo::cast (arg_type, any_t); + raw_args[i] = cast.call (builder, arg); + } + + llvm::Value *result = raw_fn.call (builder, raw_args); + if (raw_fn.result ()) + { + jit_type *raw_result_t = raw_fn.result (); + const jit_function& cast = jit_typeinfo::cast (any_t, raw_result_t); + result = cast.call (builder, result); + } + else + { + llvm::Value *zero = builder.getInt32 (0); + result = builder.CreateBitCast (zero, any_t->to_llvm ()); + } + + wrapper.do_return (builder, result); + + llvm::Function *llvm_function = wrapper.to_llvm (); + tjit.optimize (llvm_function); + +#ifdef OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + std::cout << "-------------------- optimized and wrapped "; + std::cout << "--------------------\n"; + std::cout << *llvm_function << std::endl; + } +#endif + + llvm::ExecutionEngine* engine = tjit.get_engine (); + void *void_fn = engine->getPointerToFunction (llvm_function); + function = reinterpret_cast (void_fn); + } + catch (const jit_fail_exception& e) + { + argument_types.clear (); +#ifdef OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + if (e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; + } +#endif + } +} + +bool +jit_function_info::execute (const octave_value_list& ov_args, + octave_value_list& retval) const +{ + if (! function) + return false; + + // TODO figure out a way to delete ov_args so we avoid duplicating refcount + size_t nargs = ov_args.length (); + std::vector args (nargs); + for (size_t i = 0; i < nargs; ++i) + { + octave_base_value *obv = ov_args(i).internal_rep (); + obv->grab (); + args[i] = obv; + } + + octave_base_value *ret = function (&args[0]); + if (ret) + retval(0) = octave_value (ret); + + return true; +} + +bool +jit_function_info::match (const octave_value_list& ov_args) const +{ + if (! function) + return true; + + size_t nargs = ov_args.length (); + if (nargs != argument_types.size ()) + return false; + + for (size_t i = 0; i < nargs; ++i) + if (jit_typeinfo::type_of (ov_args(i)) != argument_types[i]) + return false; + + return true; +} + // -------------------- jit_info -------------------- jit_info::jit_info (tree_jit& tjit, tree& tee) : engine (tjit.get_engine ()), function (0), llvm_function (0) @@ -1739,8 +2125,9 @@ jit_factory& factory = conv.get_factory (); jit_convert_llvm to_llvm; - llvm_function = to_llvm.convert (tjit.get_module (), infer.get_blocks (), - factory.constants ()); + llvm_function = to_llvm.convert_loop (tjit.get_module (), + infer.get_blocks (), + factory.constants ()); arguments = to_llvm.get_arguments (); bounds = conv.get_bounds (); } @@ -2126,4 +2513,13 @@ %!error (test_undef); +%!shared id +%! id = @(x) x; + +%!assert (id (1), 1); +%!assert (id (1+1i), 1+1i) +%!assert (id (1, 2), 1) +%!error (id ()) + + */ diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/interp-core/pt-jit.h --- a/libinterp/interp-core/pt-jit.h Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/interp-core/pt-jit.h Sun Sep 09 00:29:00 2012 -0600 @@ -26,8 +26,10 @@ #ifdef HAVE_LLVM #include "jit-ir.h" +#include "pt-walk.h" +#include "symtab.h" -#include "pt-walk.h" +class octave_value_list; // Convert from the parse tree (AST) to the low level Octave IR. class @@ -40,6 +42,8 @@ jit_convert (tree &tee, jit_type *for_bounds = 0); + jit_convert (octave_user_function& fcn, const std::vector& args); + #define DECL_ARG(n) const ARG ## n& arg ## n #define JIT_CREATE_CHECKED(N) \ template \ @@ -156,6 +160,11 @@ std::vector > arguments; type_bound_vector bounds; + bool converting_function; + + // the scope of the function we are converting, or the current scope + symbol_table::scope_id scope; + jit_factory factory; // used instead of return values from visit_* functions @@ -179,6 +188,8 @@ variable_map vmap; + void initialize (symbol_table::scope_id s); + jit_call *create_checked_impl (jit_call *ret); // get an existing vairable. If the variable does not exist, it will not be @@ -191,7 +202,8 @@ // create a variable of the given name and given type. Will also insert an // extract statement - jit_variable *create_variable (const std::string& vname, jit_type *type); + jit_variable *create_variable (const std::string& vname, jit_type *type, + bool isarg = true); // The name of the next for loop iterator. If inc is false, then the iterator // counter will not be incremented. @@ -233,10 +245,17 @@ jit_convert_llvm : public jit_ir_walker { public: - llvm::Function *convert (llvm::Module *module, - const jit_block_list& blocks, - const std::list& constants); + llvm::Function *convert_loop (llvm::Module *module, + const jit_block_list& blocks, + const std::list& constants); + jit_function convert_function (llvm::Module *module, + const jit_block_list& blocks, + const std::list& constants, + octave_user_function& fcn, + const std::vector& args); + + // arguments to the llvm::Function for loops const std::vector >& get_arguments(void) const { return argument_vec; } @@ -247,13 +266,22 @@ #undef JIT_METH private: + // name -> argument index (used for compiling functions) + std::map argument_index; + std::vector > argument_vec; - // name -> llvm argument + // name -> llvm argument (used for compiling loops) std::map arguments; + + bool converting_function; + llvm::Function *function; llvm::BasicBlock *prelude; + void convert (const jit_block_list& blocks, + const std::list& constants); + void finish_phi (jit_phi *phi); void visit (jit_value *jvalue) @@ -319,13 +347,15 @@ tree_jit { public: - tree_jit (void); - ~tree_jit (void); - bool execute (tree_simple_for_command& cmd, const octave_value& bounds); + static bool execute (tree_simple_for_command& cmd, + const octave_value& bounds); - bool execute (tree_while_command& cmd); + static bool execute (tree_while_command& cmd); + + static bool execute (octave_user_function& fcn, const octave_value_list& args, + octave_value_list& retval); llvm::ExecutionEngine *get_engine (void) const { return engine; } @@ -333,8 +363,19 @@ void optimize (llvm::Function *fn); private: + tree_jit (void); + + static tree_jit& instance (void); + bool initialize (void); + bool do_execute (tree_simple_for_command& cmd, const octave_value& bounds); + + bool do_execute (tree_while_command& cmd); + + bool do_execute (octave_user_function& fcn, const octave_value_list& args, + octave_value_list& retval); + size_t trip_count (const octave_value& bounds) const; llvm::Module *module; @@ -344,6 +385,24 @@ }; class +jit_function_info +{ +public: + jit_function_info (tree_jit& tjit, octave_user_function& fcn, + const octave_value_list& ov_args); + + bool execute (const octave_value_list& ov_args, + octave_value_list& retval) const; + + bool match (const octave_value_list& ov_args) const; +private: + typedef octave_base_value *(*jited_function)(octave_base_value**); + + std::vector argument_types; + jited_function function; +}; + +class jit_info { public: diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/octave-value/ov-usr-fcn.cc --- a/libinterp/octave-value/ov-usr-fcn.cc Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/octave-value/ov-usr-fcn.cc Sun Sep 09 00:29:00 2012 -0600 @@ -39,6 +39,7 @@ #include "ov.h" #include "pager.h" #include "pt-eval.h" +#include "pt-jit.h" #include "pt-jump.h" #include "pt-misc.h" #include "pt-pr-code.h" @@ -192,6 +193,9 @@ class_constructor (false), class_method (false), parent_scope (-1), local_scope (sid), curr_unwind_protect_frame (0) +#ifdef HAVE_LLVM + , jit_info (0) +#endif { if (cmd_list) cmd_list->mark_as_function_body (); @@ -208,6 +212,10 @@ delete lead_comm; delete trail_comm; +#ifdef HAVE_LLVM + delete jit_info; +#endif + symbol_table::erase_scope (local_scope); } @@ -372,6 +380,12 @@ if (! cmd_list) return retval; +#ifdef HAVE_LLVM + if (Venable_jit_compiler && is_special_expr () + && tree_jit::execute (*this, args, retval)) + return retval; +#endif + int nargin = args.length (); unwind_protect frame; @@ -457,23 +471,14 @@ frame.protect_var (tree_evaluator::statement_context); tree_evaluator::statement_context = tree_evaluator::function; - bool special_expr = (is_inline_function () || is_anonymous_function ()); - BEGIN_PROFILER_BLOCK (profiler_name ()) - if (special_expr) + if (is_special_expr ()) { - assert (cmd_list->length () == 1); - - tree_statement *stmt = 0; + tree_expression *expr = special_expr (); - if ((stmt = cmd_list->front ()) - && stmt->is_expression ()) - { - tree_expression *expr = stmt->expression (); - - retval = expr->rvalue (nargout); - } + if (expr) + retval = expr->rvalue (nargout); } else cmd_list->accept (*current_evaluator); @@ -497,7 +502,7 @@ // Copy return values out. - if (ret_list && ! special_expr) + if (ret_list && ! is_special_expr ()) { ret_list->initialize_undefined_elements (my_name, nargout, Matrix ()); @@ -529,6 +534,16 @@ tw.visit_octave_user_function (*this); } +tree_expression * +octave_user_function::special_expr (void) +{ + assert (is_special_expr ()); + assert (cmd_list->length () == 1); + + tree_statement *stmt = cmd_list->front (); + return stmt->expression (); +} + bool octave_user_function::subsasgn_optimization_ok (void) { diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/octave-value/ov-usr-fcn.h --- a/libinterp/octave-value/ov-usr-fcn.h Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/octave-value/ov-usr-fcn.h Sun Sep 09 00:29:00 2012 -0600 @@ -41,8 +41,13 @@ class tree_parameter_list; class tree_statement_list; class tree_va_return_list; +class tree_expression; class tree_walker; +#ifdef HAVE_LLVM +class jit_function_info; +#endif + class octave_user_code : public octave_function { @@ -283,6 +288,14 @@ : false; } + // If we are a special expression, then the function body consists of exactly + // one expression. The expression's result is the return value of the + // function. + bool is_special_expr (void) const + { + return is_inline_function () || is_anonymous_function (); + } + bool is_nested_function (void) const { return nested_function; } void mark_as_nested_function (void) { nested_function = true; } @@ -335,6 +348,10 @@ octave_comment_list *trailing_comment (void) { return trail_comm; } + // If is_special_expr is true, retrieve the sigular expression that forms the + // body. May be null (even if is_special_expr is true). + tree_expression *special_expr (void); + bool subsasgn_optimization_ok (void); void accept (tree_walker& tw); @@ -351,6 +368,12 @@ return false; } +#ifdef HAVE_LLVM + jit_function_info *get_info (void) { return jit_info; } + + void stash_info (jit_function_info *info) { jit_info = info; } +#endif + #if 0 void print_symtab_info (std::ostream& os) const; #endif @@ -427,6 +450,10 @@ // pointer to the current unwind_protect frame of this function. unwind_protect *curr_unwind_protect_frame; +#ifdef HAVE_LLVM + jit_function_info *jit_info; +#endif + #if 0 // The symbol record for argn in the local symbol table. octave_value& argn_varref; diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/parse-tree/pt-decl.h --- a/libinterp/parse-tree/pt-decl.h Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/parse-tree/pt-decl.h Sun Sep 09 00:29:00 2012 -0600 @@ -84,6 +84,8 @@ tree_identifier *ident (void) { return id; } + std::string name (void) { return id ? id->name () : ""; } + tree_expression *expression (void) { return expr; } tree_decl_elt *dup (symbol_table::scope_id scope, diff -r 5fff79162342 -r 3f43e9d6d86e libinterp/parse-tree/pt-eval.cc --- a/libinterp/parse-tree/pt-eval.cc Sat Sep 08 18:47:29 2012 -0700 +++ b/libinterp/parse-tree/pt-eval.cc Sun Sep 09 00:29:00 2012 -0600 @@ -47,10 +47,6 @@ //FIXME: This should be part of tree_evaluator #include "pt-jit.h" -#if HAVE_LLVM -static tree_jit jiter; -#endif - static tree_evaluator std_evaluator; tree_evaluator *current_evaluator = &std_evaluator; @@ -311,7 +307,7 @@ octave_value rhs = expr->rvalue1 (); #if HAVE_LLVM - if (Venable_jit_compiler && jiter.execute (cmd, rhs)) + if (Venable_jit_compiler && tree_jit::execute (cmd, rhs)) return; #endif @@ -1048,7 +1044,7 @@ return; #if HAVE_LLVM - if (Venable_jit_compiler && jiter.execute (cmd)) + if (Venable_jit_compiler && tree_jit::execute (cmd)) return; #endif