Mercurial > octave
view src/pt-jit.cc @ 14901:516b4a15b775
doc: Copyright fix in pt-jit.h and pt-jit.cc
author | Max Brister <max@2bass.com> |
---|---|
date | Mon, 07 May 2012 18:37:31 -0600 |
parents | f25d2224fa02 |
children | 54ea692b8ab5 |
line wrap: on
line source
/* Copyright (C) 2012 Max Brister <max@2bass.com> This file is part of Octave. Octave is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version. Octave is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Octave; see the file COPYING. If not, see <http://www.gnu.org/licenses/>. */ #define __STDC_LIMIT_MACROS #define __STDC_CONSTANT_MACROS #ifdef HAVE_CONFIG_H #include <config.h> #endif #include "pt-jit.h" #include <typeinfo> #include <llvm/LLVMContext.h> #include <llvm/Module.h> #include <llvm/Function.h> #include <llvm/BasicBlock.h> #include <llvm/Support/IRBuilder.h> #include <llvm/ExecutionEngine/ExecutionEngine.h> #include <llvm/ExecutionEngine/JIT.h> #include <llvm/PassManager.h> #include <llvm/Analysis/Verifier.h> #include <llvm/Analysis/Passes.h> #include <llvm/Target/TargetData.h> #include <llvm/Transforms/Scalar.h> #include <llvm/Support/TargetSelect.h> #include <llvm/Support/raw_os_ostream.h> #include "ov-fcn-handle.h" #include "ov-usr-fcn.h" #include "pt-all.h" using namespace llvm; //FIXME: Move into tree_jit static IRBuilder<> builder (getGlobalContext ()); extern "C" void octave_print_double (const char *name, double value) { // FIXME: We should avoid allocating a new octave_scalar each time octave_value ov (value); ov.print_with_name (octave_stdout, name); } tree_jit::tree_jit (void) : context (getGlobalContext ()), engine (0) { InitializeNativeTarget (); InitializeNativeTargetAsmPrinter (); module = new Module ("octave", context); } tree_jit::~tree_jit (void) { delete module; } bool tree_jit::execute (tree& tee) { if (!engine) { engine = ExecutionEngine::createJIT (module); // initialize pass manager pass_manager = new FunctionPassManager (module); pass_manager->add (new TargetData(*engine->getTargetData ())); pass_manager->add (createBasicAliasAnalysisPass ()); pass_manager->add (createPromoteMemoryToRegisterPass ()); pass_manager->add (createInstructionCombiningPass ()); pass_manager->add (createReassociatePass ()); pass_manager->add (createGVNPass ()); pass_manager->add (createCFGSimplificationPass ()); pass_manager->doInitialization (); // create external functions Type *vtype = Type::getVoidTy (context); std::vector<Type*> pd_args (2); pd_args[0] = Type::getInt8PtrTy (context); pd_args[1] = Type::getDoubleTy (context); FunctionType *print_double_ty = FunctionType::get (vtype, pd_args, false); print_double = Function::Create (print_double_ty, Function::ExternalLinkage, "octave_print_double", module); engine->addGlobalMapping (print_double, reinterpret_cast<void*>(&octave_print_double)); } if (!engine) // sometimes this fails during early initialization return false; // find function function_info *finfo; finfo_map_iterator iter = compiled_functions.find (&tee); if (iter == compiled_functions.end ()) finfo = compile (tee); else finfo = iter->second; return finfo->execute (); } tree_jit::function_info* tree_jit::compile (tree& tee) { value_stack.clear (); variables.clear (); // setup function std::vector<Type*> args (2); args[0] = Type::getInt1PtrTy (context); args[1] = Type::getDoublePtrTy (context); FunctionType *ft = FunctionType::get (Type::getVoidTy (context), args, false); Function *compiling = Function::Create (ft, Function::ExternalLinkage, "test_fn", module); entry_block = BasicBlock::Create (context, "entry", compiling); BasicBlock *body = BasicBlock::Create (context, "body", compiling); builder.SetInsertPoint (body); // convert tree to LLVM IR try { tee.accept (*this); } catch (const jit_fail_exception&) { //FIXME: cleanup return compiled_functions[&tee] = new function_info (); } // copy input arguments builder.SetInsertPoint (entry_block); Function::arg_iterator arg_iter = compiling->arg_begin (); Value *arg_defined = arg_iter; Value *arg_value = ++arg_iter; arg_defined->setName ("arg_defined"); arg_value->setName ("arg_value"); size_t idx = 0; std::vector<std::string> arg_names; std::vector<bool> arg_used; for (var_map_iterator iter = variables.begin (); iter != variables.end (); ++iter, ++idx) { arg_names.push_back (iter->first); arg_used.push_back (iter->second.use); Value *gep_defined = builder.CreateConstInBoundsGEP1_32 (arg_defined, idx); Value *defined = builder.CreateLoad (gep_defined); builder.CreateStore (defined, iter->second.defined); Value *gep_value = builder.CreateConstInBoundsGEP1_32 (arg_value, idx); Value *value = builder.CreateLoad (gep_value); builder.CreateStore (value, iter->second.value); } builder.CreateBr (body); // copy output arguments BasicBlock *cleanup = BasicBlock::Create (context, "cleanup", compiling); builder.SetInsertPoint (body); builder.CreateBr (cleanup); builder.SetInsertPoint (cleanup); idx = 0; for (var_map_iterator iter = variables.begin (); iter != variables.end (); ++iter, ++idx) { Value *gep_defined = builder.CreateConstInBoundsGEP1_32 (arg_defined, idx); Value *defined = builder.CreateLoad (iter->second.defined); builder.CreateStore (defined, gep_defined); Value *gep_value = builder.CreateConstInBoundsGEP1_32 (arg_value, idx); Value *value = builder.CreateLoad (iter->second.value, iter->first); builder.CreateStore (value, gep_value); } builder.CreateRetVoid (); // print what we compiled (for debugging) // we leave this in for now, as other people might want to view the ir created // should be removed eventually though const bool debug_print_ir = false; if (debug_print_ir) { raw_os_ostream os (std::cout); std::cout << "Compiling --------------------\n"; tree_print_code tpc (std::cout); std::cout << typeid (tee).name () << std::endl; tee.accept (tpc); std::cout << "\n--------------------\n"; std::cout << "llvm_ir\n"; compiling->print (os); std::cout << "--------------------\n"; } // compile code verifyFunction (*compiling); pass_manager->run (*compiling); if (debug_print_ir) { raw_os_ostream os (std::cout); std::cout << "optimized llvm_ir\n"; compiling->print (os); std::cout << "--------------------\n"; } jit_function fun = reinterpret_cast<jit_function> (engine->getPointerToFunction (compiling)); return compiled_functions[&tee] = new function_info (fun, arg_names, arg_used); } tree_jit::variable_info tree_jit::find (const std::string &name, bool use) { var_map_iterator iter = variables.find (name); if (iter == variables.end ()) { // we currently just assume everything is a double Type *dbl = Type::getDoubleTy (context); Type *bol = Type::getInt1Ty (context); IRBuilder<> tmpB (entry_block, entry_block->begin ()); variable_info vinfo; vinfo.defined = tmpB.CreateAlloca (bol, 0); vinfo.value = tmpB.CreateAlloca (dbl, 0, name); vinfo.use = use; variables[name] = vinfo; return vinfo; } else { iter->second.use = iter->second.use || use; return iter->second; } } void tree_jit::do_assign (variable_info vinfo, llvm::Value *value) { // create assign expression Value *result = builder.CreateStore (value, vinfo.value); value_stack.push_back (result); // update defined for lhs Type *btype = Type::getInt1Ty (context); Value *btrue = ConstantInt::get (btype, APInt (1, 1)); builder.CreateStore (btrue, vinfo.defined); } void tree_jit::emit_print (const std::string& vname, llvm::Value *value) { Value *pname = builder.CreateGlobalStringPtr (vname); builder.CreateCall2 (print_double, pname, value); } void tree_jit::visit_anon_fcn_handle (tree_anon_fcn_handle&) { fail (); } void tree_jit::visit_argument_list (tree_argument_list&) { fail (); } void tree_jit::visit_binary_expression (tree_binary_expression& be) { tree_expression *lhs = be.lhs (); tree_expression *rhs = be.rhs (); if (lhs && rhs) { lhs->accept (*this); rhs->accept (*this); Value *lhsv = value_stack.back (); value_stack.pop_back (); Value *rhsv = value_stack.back (); value_stack.pop_back (); Value *result; switch (be.op_type ()) { case octave_value::op_add: result = builder.CreateFAdd (lhsv, rhsv); break; case octave_value::op_sub: result = builder.CreateFSub (lhsv, rhsv); break; case octave_value::op_mul: result = builder.CreateFMul (lhsv, rhsv); break; case octave_value::op_div: result = builder.CreateFDiv (lhsv, rhsv); break; default: fail (); } value_stack.push_back (result); } else fail (); } void tree_jit::visit_break_command (tree_break_command&) { fail (); } void tree_jit::visit_colon_expression (tree_colon_expression&) { fail (); } void tree_jit::visit_continue_command (tree_continue_command&) { fail (); } void tree_jit::visit_global_command (tree_global_command&) { fail (); } void tree_jit::visit_persistent_command (tree_persistent_command&) { fail (); } void tree_jit::visit_decl_elt (tree_decl_elt&) { fail (); } void tree_jit::visit_decl_init_list (tree_decl_init_list&) { fail (); } void tree_jit::visit_simple_for_command (tree_simple_for_command&) { fail (); } void tree_jit::visit_complex_for_command (tree_complex_for_command&) { fail (); } void tree_jit::visit_octave_user_script (octave_user_script&) { fail (); } void tree_jit::visit_octave_user_function (octave_user_function&) { fail (); } void tree_jit::visit_octave_user_function_header (octave_user_function&) { fail (); } void tree_jit::visit_octave_user_function_trailer (octave_user_function&) { fail (); } void tree_jit::visit_function_def (tree_function_def&) { fail (); } void tree_jit::visit_identifier (tree_identifier& ti) { octave_value ov = ti.do_lookup (); if (ov.is_function ()) fail (); std::string name = ti.name (); variable_info vinfo = find (ti.name (), true); // TODO check defined Value *load_value = builder.CreateLoad (vinfo.value, name); value_stack.push_back (load_value); } void tree_jit::visit_if_clause (tree_if_clause&) { fail (); } void tree_jit::visit_if_command (tree_if_command&) { fail (); } void tree_jit::visit_if_command_list (tree_if_command_list&) { fail (); } void tree_jit::visit_index_expression (tree_index_expression&) { fail (); } void tree_jit::visit_matrix (tree_matrix&) { fail (); } void tree_jit::visit_cell (tree_cell&) { fail (); } void tree_jit::visit_multi_assignment (tree_multi_assignment&) { fail (); } void tree_jit::visit_no_op_command (tree_no_op_command&) { fail (); } void tree_jit::visit_constant (tree_constant& tc) { octave_value v = tc.rvalue1 (); if (v.is_real_scalar () && v.is_double_type ()) { double dv = v.double_value (); Value *lv = ConstantFP::get (context, APFloat (dv)); value_stack.push_back (lv); } else fail (); } void tree_jit::visit_fcn_handle (tree_fcn_handle&) { fail (); } void tree_jit::visit_parameter_list (tree_parameter_list&) { fail (); } void tree_jit::visit_postfix_expression (tree_postfix_expression&) { fail (); } void tree_jit::visit_prefix_expression (tree_prefix_expression&) { fail (); } void tree_jit::visit_return_command (tree_return_command&) { fail (); } void tree_jit::visit_return_list (tree_return_list&) { fail (); } void tree_jit::visit_simple_assignment (tree_simple_assignment& tsa) { // only support an identifier as lhs tree_identifier *lhs = dynamic_cast<tree_identifier*> (tsa.left_hand_side ()); if (!lhs) fail (); variable_info lhsv = find (lhs->name (), false); // resolve rhs as normal tree_expression *rhs = tsa.right_hand_side (); rhs->accept (*this); Value *rhsv = value_stack.back (); value_stack.pop_back (); do_assign (lhsv, rhsv); if (tsa.print_result ()) emit_print (lhs->name (), rhsv); } void tree_jit::visit_statement (tree_statement& stmt) { tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); if (cmd) cmd->accept (*this); else { // TODO deal with printing // stolen from tree_evaluator::visit_statement bool do_bind_ans = false; if (expr->is_identifier ()) { tree_identifier *id = dynamic_cast<tree_identifier *> (expr); do_bind_ans = (! id->is_variable ()); } else do_bind_ans = (! expr->is_assignment_expression ()); expr->accept (*this); if (do_bind_ans) { Value *rhs = value_stack.back (); value_stack.pop_back (); variable_info ans = find ("ans", false); do_assign (ans, rhs); } else if (expr->is_identifier () && expr->print_result ()) { // FIXME: ugly hack, we need to come up with a way to pass // nargout to visit_identifier emit_print (expr->name (), value_stack.back ()); } value_stack.pop_back (); } } void tree_jit::visit_statement_list (tree_statement_list&) { fail (); } void tree_jit::visit_switch_case (tree_switch_case&) { fail (); } void tree_jit::visit_switch_case_list (tree_switch_case_list&) { fail (); } void tree_jit::visit_switch_command (tree_switch_command&) { fail (); } void tree_jit::visit_try_catch_command (tree_try_catch_command&) { fail (); } void tree_jit::visit_unwind_protect_command (tree_unwind_protect_command&) { fail (); } void tree_jit::visit_while_command (tree_while_command&) { fail (); } void tree_jit::visit_do_until_command (tree_do_until_command&) { fail (); } void tree_jit::fail (void) { throw jit_fail_exception (); } tree_jit::function_info::function_info (void) : function (0) {} tree_jit::function_info::function_info (jit_function fun, const std::vector<std::string>& args, const std::vector<bool>& arg_used) : function (fun), arguments (args), argument_used (arg_used) {} bool tree_jit::function_info::execute () { if (! function) return false; // FIXME: we are doing hash lookups every time, this has got to be slow unwind_protect up; bool *args_defined = new bool[arguments.size ()]; // vector<bool> sucks up.add_delete (args_defined); std::vector<double> args_values (arguments.size ()); for (size_t i = 0; i < arguments.size (); ++i) { octave_value ov = symbol_table::varval (arguments[i]); if (argument_used[i]) { if (! (ov.is_double_type () && ov.is_real_scalar ())) return false; args_defined[i] = ov.is_defined (); args_values[i] = ov.double_value (); } else args_defined[i] = false; } function (args_defined, &args_values[0]); for (size_t i = 0; i < arguments.size (); ++i) if (args_defined[i]) symbol_table::varref (arguments[i]) = octave_value (args_values[i]); return true; }