# HG changeset patch # User Jordi GutiĆ©rrez Hermoso # Date 1343342752 14400 # Node ID 7aa103a1c8aed7fccddeef9e0b3ba7c6a66ec563 # Parent 397f0d80bd47479eebffd8300047be3f158ebb9d# Parent 741d2dbcc1172c4c80d25399006c5022314d91a0 Merge in Doug's changes diff -r 397f0d80bd47 -r 7aa103a1c8ae src/jit-typeinfo.cc --- a/src/jit-typeinfo.cc Thu Jul 26 17:59:30 2012 -0400 +++ b/src/jit-typeinfo.cc Thu Jul 26 18:45:52 2012 -0400 @@ -138,6 +138,25 @@ obv->release (); } +extern "C" octave_base_value * +octave_jit_cast_any_range (jit_range *rng) +{ + Range temp (*rng); + octave_value ret (temp); + octave_base_value *rep = ret.internal_rep (); + rep->grab (); + + return rep; +} +extern "C" void +octave_jit_cast_range_any (jit_range *ret, octave_base_value *obv) +{ + + jit_range r (obv->range_value ()); + *ret = r; + obv->release (); +} + extern "C" double octave_jit_cast_scalar_any (octave_base_value *obv) { @@ -210,8 +229,8 @@ } extern "C" void -octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, - double value) +octave_jit_paren_subsasgn_impl (jit_matrix *ret, jit_matrix *mat, + octave_idx_type index, double value) { NDArray *array = mat->array; if (array->nelem () < index) @@ -221,6 +240,7 @@ data[index - 1] = value; mat->update (); + *ret = *mat; } extern "C" void @@ -1291,7 +1311,8 @@ jit_function resize_paren_subsasgn = create_function (jit_convention::external, - "octave_jit_paren_subsasgn_impl", matrix, index, scalar); + "octave_jit_paren_subsasgn_impl", matrix, matrix, index, + scalar); resize_paren_subsasgn.add_mapping (engine, &octave_jit_paren_subsasgn_impl); fn = create_function (jit_convention::internal, "octave_jit_paren_subsasgn", matrix, matrix, scalar, scalar); @@ -1336,8 +1357,8 @@ // resize on out of bounds access builder.SetInsertPoint (bounds_error); - llvm::Value *resize_result = resize_paren_subsasgn.call (builder, int_idx, - value); + llvm::Value *resize_result = resize_paren_subsasgn.call (builder, mat, + int_idx, value); builder.CreateBr (done); builder.SetInsertPoint (success); @@ -1369,6 +1390,7 @@ casts[scalar->type_id ()].stash_name ("(scalar)"); casts[complex->type_id ()].stash_name ("(complex)"); casts[matrix->type_id ()].stash_name ("(matrix)"); + casts[any->type_id ()].stash_name ("(range)"); // cast any <- matrix fn = create_function (jit_convention::external, "octave_jit_cast_any_matrix", @@ -1382,6 +1404,18 @@ fn.add_mapping (engine, &octave_jit_cast_matrix_any); casts[matrix->type_id ()].add_overload (fn); + // cast any <- range + fn = create_function (jit_convention::external, "octave_jit_cast_any_range", + any, range); + fn.add_mapping (engine, &octave_jit_cast_any_range); + casts[any->type_id ()].add_overload (fn); + + // cast range <- any + fn = create_function (jit_convention::external, "octave_jit_cast_range_any", + range, any); + fn.add_mapping (engine, &octave_jit_cast_range_any); + casts[range->type_id ()].add_overload (fn); + // cast any <- scalar fn = create_function (jit_convention::external, "octave_jit_cast_any_scalar", any, scalar); diff -r 397f0d80bd47 -r 7aa103a1c8ae src/pt-eval.cc --- a/src/pt-eval.cc Thu Jul 26 17:59:30 2012 -0400 +++ b/src/pt-eval.cc Thu Jul 26 18:45:52 2012 -0400 @@ -296,11 +296,6 @@ if (debug_mode) do_breakpoint (cmd.is_breakpoint ()); -#if HAVE_LLVM - if (jiter.execute (cmd)) - return; -#endif - // FIXME -- need to handle PARFOR loops here using cmd.in_parallel () // and cmd.maxproc_expr (); @@ -314,6 +309,11 @@ octave_value rhs = expr->rvalue1 (); +#if HAVE_LLVM + if (jiter.execute (cmd, rhs)) + return; +#endif + if (error_state || rhs.is_undefined ()) return; diff -r 397f0d80bd47 -r 7aa103a1c8ae src/pt-jit.cc --- a/src/pt-jit.cc Thu Jul 26 17:59:30 2012 -0400 +++ b/src/pt-jit.cc Thu Jul 26 18:45:52 2012 -0400 @@ -57,8 +57,9 @@ static llvm::LLVMContext& context = llvm::getGlobalContext (); // -------------------- jit_convert -------------------- -jit_convert::jit_convert (llvm::Module *module, tree &tee) - : iterator_count (0), short_count (0), breaking (false) +jit_convert::jit_convert (llvm::Module *module, tree &tee, + jit_type *for_bounds) + : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false) { jit_instruction::reset_ids (); @@ -67,6 +68,10 @@ append (entry_block); entry_block->mark_alive (); block = entry_block; + + if (for_bounds) + create_variable (next_for_bounds (false), for_bounds); + visit (tee); // FIXME: Remove if we no longer only compile loops @@ -175,10 +180,7 @@ assert (boole); bool is_and = boole->op_type () == tree_boolean_expression::bool_and; - std::stringstream ss; - ss << "#short_result" << short_count++; - - std::string short_name = ss.str (); + std::string short_name = next_shortcircut_result (); jit_variable *short_result = create (short_name); vmap[short_name] = short_result; @@ -302,10 +304,9 @@ continues.clear (); // we need a variable for our iterator, because it is used in multiple blocks - std::stringstream ss; - ss << "#iter" << iterator_count++; - std::string iter_name = ss.str (); + std::string iter_name = next_iterator (); jit_variable *iterator = create (iter_name); + create (iter_name); vmap[iter_name] = iterator; jit_block *body = create ("for_body"); @@ -314,7 +315,10 @@ jit_block *tail = create ("for_tail"); // do control expression, iter init, and condition check in prev_block (block) - jit_value *control = visit (cmd.control_expr ()); + // if we are the top level for loop, the bounds is an input argument. + jit_value *control = find_variable (next_for_bounds ()); + if (! control) + control = visit (cmd.control_expr ()); jit_call *init_iter = create (jit_typeinfo::for_init, control); block->append (init_iter); block->append (create (iterator, init_iter)); @@ -762,21 +766,43 @@ } jit_variable * +jit_convert::find_variable (const std::string& vname) const +{ + vmap_t::const_iterator iter; + iter = vmap.find (vname); + return iter != vmap.end () ? iter->second : 0; +} + +jit_variable * jit_convert::get_variable (const std::string& vname) { - vmap_t::iterator iter; - iter = vmap.find (vname); - if (iter != vmap.end ()) - return iter->second; + jit_variable *ret = find_variable (vname); + if (ret) + return ret; - jit_variable *var = create (vname); octave_value val = symbol_table::find (vname); jit_type *type = jit_typeinfo::type_of (val); + return create_variable (vname, type); +} + +jit_variable * +jit_convert::create_variable (const std::string& vname, jit_type *type) +{ + jit_variable *var = create (vname); jit_extract_argument *extract; extract = create (type, var); entry_block->prepend (extract); + return vmap[vname] = var; +} - return vmap[vname] = var; +std::string +jit_convert::next_name (const char *prefix, size_t& count, bool inc) +{ + std::stringstream ss; + ss << prefix << count; + if (inc) + ++count; + return ss.str (); } std::pair @@ -1462,20 +1488,29 @@ {} bool -tree_jit::execute (tree_simple_for_command& cmd) +tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) { - if (! initialize ()) + const size_t MIN_TRIP_COUNT = 1000; + + size_t tc = trip_count (bounds); + if (! tc || ! initialize ()) return false; + jit_info::vmap extra_vars; + extra_vars["#for_bounds0"] = &bounds; + jit_info *info = cmd.get_info (); - if (! info || ! info->match ()) + if (! info || ! info->match (extra_vars)) { + if (tc < MIN_TRIP_COUNT) + return false; + delete info; - info = new jit_info (*this, cmd); + info = new jit_info (*this, cmd, bounds); cmd.stash_info (info); } - return info->execute (); + return info->execute (extra_vars); } bool @@ -1531,6 +1566,19 @@ return true; } +size_t +tree_jit::trip_count (const octave_value& bounds) const +{ + if (bounds.is_range ()) + { + Range rng = bounds.range_value (); + return rng.nelem (); + } + + // unsupported type + return 0; +} + void tree_jit::optimize (llvm::Function *fn) @@ -1548,14 +1596,12 @@ // -------------------- jit_info -------------------- jit_info::jit_info (tree_jit& tjit, tree& tee) - : engine (tjit.get_engine ()), llvm_function (0) + : engine (tjit.get_engine ()), function (0), llvm_function (0) { try { jit_convert conv (tjit.get_module (), tee); - llvm_function = conv.get_function (); - arguments = conv.get_arguments (); - bounds = conv.get_bounds (); + initialize (tjit, conv); } catch (const jit_fail_exception& e) { @@ -1564,24 +1610,24 @@ std::cout << "jit fail: " << e.what () << std::endl; #endif } - - if (! llvm_function) - { - function = 0; - return; - } - - tjit.optimize (llvm_function); +} +jit_info::jit_info (tree_jit& tjit, tree& tee, const octave_value& for_bounds) + : engine (tjit.get_engine ()), function (0), llvm_function (0) +{ + try + { + jit_convert conv (tjit.get_module (), tee, + jit_typeinfo::type_of (for_bounds)); + initialize (tjit, conv); + } + catch (const jit_fail_exception& e) + { #ifdef OCTAVE_JIT_DEBUG - std::cout << "-------------------- optimized llvm ir --------------------\n"; - llvm::raw_os_ostream llvm_cout (std::cout); - llvm_function->print (llvm_cout); - std::cout << std::endl; + if (e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; #endif - - void *void_fn = engine->getPointerToFunction (llvm_function); - function = reinterpret_cast (void_fn); + } } jit_info::~jit_info (void) @@ -1591,7 +1637,7 @@ } bool -jit_info::execute (void) const +jit_info::execute (const vmap& extra_vars) const { if (! function) return false; @@ -1601,24 +1647,29 @@ { if (arguments[i].second) { - octave_value ¤t = symbol_table::varref (arguments[i].first); + octave_value current = find (extra_vars, arguments[i].first); octave_base_value *obv = current.internal_rep (); obv->grab (); real_arguments[i] = obv; - current = octave_value (); } } function (&real_arguments[0]); for (size_t i = 0; i < arguments.size (); ++i) - symbol_table::varref (arguments[i].first) = real_arguments[i]; + { + const std::string& name = arguments[i].first; + + // do not store for loop bounds temporary + if (name.size () && name[0] != '#') + symbol_table::varref (arguments[i].first) = real_arguments[i]; + } return true; } bool -jit_info::match (void) const +jit_info::match (const vmap& extra_vars) const { if (! function) return true; @@ -1626,7 +1677,7 @@ for (size_t i = 0; i < bounds.size (); ++i) { const std::string& arg_name = bounds[i].second; - octave_value value = symbol_table::find (arg_name); + octave_value value = find (extra_vars, arg_name); jit_type *type = jit_typeinfo::type_of (value); // FIXME: Check for a parent relationship @@ -1636,6 +1687,40 @@ return true; } + +void +jit_info::initialize (tree_jit& tjit, jit_convert& conv) +{ + llvm_function = conv.get_function (); + arguments = conv.get_arguments (); + bounds = conv.get_bounds (); + + if (llvm_function) + { + tjit.optimize (llvm_function); + +#ifdef OCTAVE_JIT_DEBUG + std::cout << "-------------------- optimized llvm ir " + << "--------------------\n"; + llvm::raw_os_ostream llvm_cout (std::cout); + llvm_function->print (llvm_cout); + llvm_cout.flush (); + std::cout << std::endl; +#endif + + void *void_fn = engine->getPointerToFunction (llvm_function); + function = reinterpret_cast (void_fn); + } +} + +octave_value +jit_info::find (const vmap& extra_vars, const std::string& vname) const +{ + vmap::const_iterator iter = extra_vars.find (vname); + return iter == extra_vars.end () ? symbol_table::varval (vname) + : *iter->second; +} + #endif diff -r 397f0d80bd47 -r 7aa103a1c8ae src/pt-jit.h --- a/src/pt-jit.h Thu Jul 26 17:59:30 2012 -0400 +++ b/src/pt-jit.h Thu Jul 26 18:45:52 2012 -0400 @@ -64,7 +64,7 @@ typedef std::pair type_bound; typedef std::vector type_bound_vector; - jit_convert (llvm::Module *module, tree &tee); + jit_convert (llvm::Module *module, tree &tee, jit_type *for_bounds = 0); ~jit_convert (void); @@ -245,6 +245,7 @@ std::list all_values; size_t iterator_count; + size_t for_bounds_count; size_t short_count; typedef std::map vmap_t; @@ -268,8 +269,31 @@ return ret; } + // get an existing vairable. If the variable does not exist, it will not be + // created + jit_variable *find_variable (const std::string& vname) const; + + // get a variable, create it if it does not exist. The type will default to + // the variable's current type in the symbol table. jit_variable *get_variable (const std::string& vname); + // create a variable of the given name and given type. Will also insert an + // extract statement + jit_variable *create_variable (const std::string& vname, jit_type *type); + + // The name of the next for loop iterator. If inc is false, then the iterator + // counter will not be incremented. + std::string next_iterator (bool inc = true) + { return next_name ("#iter", iterator_count, inc); } + + std::string next_for_bounds (bool inc = true) + { return next_name ("#for_bounds", for_bounds_count, inc); } + + std::string next_shortcircut_result (bool inc = true) + { return next_name ("#shortcircut_result", short_count, inc); } + + std::string next_name (const char *prefix, size_t& count, bool inc); + std::pair resolve (tree_index_expression& exp); jit_value *do_assign (tree_expression *exp, jit_value *rhs, @@ -404,7 +428,7 @@ ~tree_jit (void); - bool execute (tree_simple_for_command& cmd); + bool execute (tree_simple_for_command& cmd, const octave_value& bounds); bool execute (tree_while_command& cmd); @@ -416,6 +440,8 @@ private: bool initialize (void); + size_t trip_count (const octave_value& bounds) const; + // FIXME: Temorary hack to test typedef std::map compiled_map; llvm::Module *module; @@ -428,18 +454,27 @@ jit_info { public: + // we use a pointer here so we don't have to include ov.h + typedef std::map vmap; + jit_info (tree_jit& tjit, tree& tee); + jit_info (tree_jit& tjit, tree& tee, const octave_value& for_bounds); + ~jit_info (void); - bool execute (void) const; + bool execute (const vmap& extra_vars = vmap ()) const; - bool match (void) const; + bool match (const vmap& extra_vars = vmap ()) const; private: typedef jit_convert::type_bound type_bound; typedef jit_convert::type_bound_vector type_bound_vector; typedef void (*jited_function)(octave_base_value**); + void initialize (tree_jit& tjit, jit_convert& conv); + + octave_value find (const vmap& extra_vars, const std::string& vname) const; + llvm::ExecutionEngine *engine; jited_function function; llvm::Function *llvm_function;