Mercurial > octave-nkf
diff src/pt-jit.h @ 14903:54ea692b8ab5
Reworking JIT implementation
src/TEMPLATE-INST/Array-jit.cc: New file.
src/TEMPLATE-INST/module.mk: Add Array-jit.cc.
src/ov-base.h (octave_base_value::grab,
octave_base_value::release): New functions.
src/pt-jit.cc: Rewrite.
src/pt-jit.h: Rewrite.
author | Max Brister <max@2bass.com> |
---|---|
date | Sat, 12 May 2012 19:24:32 -0600 |
parents | 516b4a15b775 |
children | 3f81e8b42955 |
line wrap: on
line diff
--- a/src/pt-jit.h Wed May 09 12:53:41 2012 -0600 +++ b/src/pt-jit.h Sat May 12 19:24:32 2012 -0600 @@ -23,31 +23,247 @@ #if !defined (octave_tree_jit_h) #define octave_tree_jit_h 1 +#include <list> #include <map> +#include <set> #include <stdexcept> #include <vector> +#include "Array.h" #include "pt-walk.h" -class jit_fail_exception : public std::exception {}; +// -------------------- Current status -------------------- +// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. +// However, there is no warning emitted on divide by 0. For example, +// a = 5; +// b = a * 5 + a; +// +// For other types all binary operations are compiled but not optimized. For +// example, +// a = [1 2 3] +// b = a + a; +// will compile to do_binary_op (a, a). +// --------------------------------------------------------- -// LLVM forward declares + +// we don't want to include llvm headers here, as they require __STDC_LIMIT_MACROS +// and __STDC_CONSTANT_MACROS be defined in the entire compilation unit namespace llvm { class Value; class Module; class FunctionPassManager; + class PassManager; class ExecutionEngine; class Function; class BasicBlock; class LLVMContext; + class Type; + class GenericValue; } +class octave_base_value; +class octave_value; class tree; +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::exception {}; + +// Used to keep track of estimated (infered) types during JIT. This is a +// hierarchical type system which includes both concrete and abstract types. +// +// Current, we only support any and scalar types. If we can't figure out what +// type a variable is, we assign it the any type. This allows us to generate +// code even for the case of poor type inference. +class +OCTINTERP_API +jit_type +{ +public: + jit_type (const std::string& n, bool fi, jit_type *mparent, llvm::Type *lt, + int tid) : + mname (n), finit (fi), p (mparent), llvm_type (lt), id (tid) + {} + + // a user readable type name + const std::string& name (void) const { return mname; } + + // do we need to initialize variables of this type, even if they are not + // input arguments? + bool force_init (void) const { return finit; } + + // a unique id for the type + int type_id (void) const { return id; } + + // An abstract base type, may be null + jit_type *parent (void) const { return p; } + + // convert to an llvm type + llvm::Type *to_llvm (void) const { return llvm_type; } + + // how this type gets passed as a function argument + llvm::Type *to_llvm_arg (void) const; +private: + std::string mname; + bool finit; + jit_type *p; + llvm::Type *llvm_type; + int id; + int depth; +}; + + +// Keeps track of overloads for a builtin function. Used for both type inference +// and code generation. +class +jit_function +{ +public: + struct overload + { + overload (void) : function (0), can_error (true), result (0) {} + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) : + function (f), can_error (e), result (r), arguments (1) + { + arguments[0] = arg0; + } + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) : function (f), can_error (e), result (r), + arguments (2) + { + arguments[0] = arg0; + arguments[1] = arg1; + } + + llvm::Function *function; + bool can_error; + jit_type *result; + std::vector<jit_type*> arguments; + }; + + void add_overload (const overload& func) + { + add_overload (func, func.arguments); + } + + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) + { + overload ol (f, e, r, arg0, arg1); + add_overload (ol); + } + + void add_overload (const overload& func, + const std::vector<jit_type*>& args); + + const overload& get_overload (const std::vector<jit_type *>& types) const; + + const overload& get_overload (jit_type *arg0) const + { + std::vector<jit_type *> types (1); + types[0] = arg0; + return get_overload (types); + } + + const overload& get_overload (jit_type *arg0, jit_type *arg1) const + { + std::vector<jit_type *> types (2); + types[0] = arg0; + types[1] = arg1; + return get_overload (types); + } + + jit_type *get_result (const std::vector<jit_type *>& types) const + { + const overload& temp = get_overload (types); + return temp.result; + } + + jit_type *get_result (jit_type *arg0, jit_type *arg1) const + { + const overload& temp = get_overload (arg0, arg1); + return temp.result; + } +private: + Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const; + + std::vector<Array<overload> > overloads; +}; + +// Get information and manipulate jit types. +class +OCTINTERP_API +jit_typeinfo +{ +public: + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e, llvm::Type *ov); + + jit_type *get_any (void) const { return any; } + + jit_type *get_scalar (void) const { return scalar; } + + jit_type *type_of (const octave_value& ov) const; + + const jit_function& binary_op (int op) const; + + const jit_function::overload& binary_op_overload (int op, jit_type *lhs, + jit_type *rhs) const + { + const jit_function& jf = binary_op (op); + return jf.get_overload (lhs, rhs); + } + + jit_type *binary_op_result (int op, jit_type *lhs, jit_type *rhs) const + { + const jit_function::overload& ol = binary_op_overload (op, lhs, rhs); + return ol.result; + } + + const jit_function::overload& assign_op (jit_type *lhs, jit_type *rhs) const; + + const jit_function::overload& print_value (jit_type *to_print) const; + + // FIXME: generic creation should probably be handled seperatly + void to_generic (jit_type *type, llvm::GenericValue& gv); + void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov); + + octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv); + + void reset_generic (size_t nargs); +private: + jit_type *new_type (const std::string& name, bool force_init, + jit_type *parent, llvm::Type *llvm_type); + + void add_print (jit_type *ty, void *call); + + void add_binary_op (jit_type *ty, int op, int llvm_op); + + llvm::Module *module; + llvm::ExecutionEngine *engine; + int next_id; + + llvm::Type *ov_t; + + std::vector<jit_type*> id_to_type; + jit_type *any; + jit_type *scalar; + + std::vector<jit_function> binary_ops; + jit_function assign_fn; + jit_function print_fn; + + size_t scalar_out_idx; + std::vector<double> scalar_out; + + size_t ov_out_idx; + std::vector<octave_base_value*> ov_out; +}; + class OCTINTERP_API -tree_jit : private tree_walker +tree_jit { public: tree_jit (void); @@ -56,146 +272,265 @@ bool execute (tree& tee); private: - typedef void (*jit_function)(bool*, double*); + typedef std::map<std::string, jit_type *> type_map; + + class + type_infer : public tree_walker + { + public: + type_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), + rvalue_type (0) + {} + + const std::set<std::string>& get_argin () const { return argin; } + + const type_map& get_types () const { return types; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); + private: + void handle_identifier (const std::string& name, octave_value v); + + jit_typeinfo *tinfo; + + bool is_lvalue; + jit_type *rvalue_type; + + type_map types; + std::set<std::string> argin; + + std::vector<jit_type *> type_stack; + }; + + class + code_generator : public tree_walker + { + public: + code_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, + const std::set<std::string>& argin, + const type_map& infered_types); + + llvm::Function *get_function () const { return function; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); + private: + typedef std::pair<jit_type *, llvm::Value *> value; + + void emit_print (const std::string& name, const value& v); + + void push_value (jit_type *type, llvm::Value *v) + { + value_stack.push_back (value (type, v)); + } + + jit_typeinfo *tinfo; + llvm::Function *function; + + bool is_lvalue; + std::map<std::string, value> variables; + std::vector<value> value_stack; + }; class function_info { public: - function_info (void); - function_info (jit_function fn, const std::vector<std::string>& args, - const std::vector<bool>& args_used); + function_info (tree_jit& tjit, tree& tee); - bool execute (); + bool execute () const; + + bool match () const; private: - jit_function function; - std::vector<std::string> arguments; - - // is the argument used? or is it just declared? - std::vector<bool> argument_used; - }; - - struct variable_info - { - llvm::Value *defined; - llvm::Value *value; - bool use; + jit_typeinfo *tinfo; + llvm::ExecutionEngine *engine; + std::set<std::string> argin; + type_map types; + llvm::Function *function; }; - function_info *compile (tree& tee); - - variable_info find (const std::string &name, bool use); - - void do_assign (variable_info vinfo, llvm::Value *value); - - void emit_print (const std::string& vname, llvm::Value *value); - - // tree_walker - void visit_anon_fcn_handle (tree_anon_fcn_handle&); - - void visit_argument_list (tree_argument_list&); - - void visit_binary_expression (tree_binary_expression&); - - void visit_break_command (tree_break_command&); - - void visit_colon_expression (tree_colon_expression&); - - void visit_continue_command (tree_continue_command&); - - void visit_global_command (tree_global_command&); - - void visit_persistent_command (tree_persistent_command&); - - void visit_decl_elt (tree_decl_elt&); - - void visit_decl_init_list (tree_decl_init_list&); - - void visit_simple_for_command (tree_simple_for_command&); - - void visit_complex_for_command (tree_complex_for_command&); - - void visit_octave_user_script (octave_user_script&); - - void visit_octave_user_function (octave_user_function&); - - void visit_octave_user_function_header (octave_user_function&); - - void visit_octave_user_function_trailer (octave_user_function&); - - void visit_function_def (tree_function_def&); - - void visit_identifier (tree_identifier&); - - void visit_if_clause (tree_if_clause&); - - void visit_if_command (tree_if_command&); - - void visit_if_command_list (tree_if_command_list&); - - void visit_index_expression (tree_index_expression&); + typedef std::list<function_info *> function_list; + typedef std::map<tree *, function_list> compiled_map; - void visit_matrix (tree_matrix&); - - void visit_cell (tree_cell&); - - void visit_multi_assignment (tree_multi_assignment&); - - void visit_no_op_command (tree_no_op_command&); - - void visit_constant (tree_constant&); - - void visit_fcn_handle (tree_fcn_handle&); - - void visit_parameter_list (tree_parameter_list&); - - void visit_postfix_expression (tree_postfix_expression&); - - void visit_prefix_expression (tree_prefix_expression&); - - void visit_return_command (tree_return_command&); - - void visit_return_list (tree_return_list&); - - void visit_simple_assignment (tree_simple_assignment&); - - void visit_statement (tree_statement&); - - void visit_statement_list (tree_statement_list&); - - void visit_switch_case (tree_switch_case&); - - void visit_switch_case_list (tree_switch_case_list&); - - void visit_switch_command (tree_switch_command&); - - void visit_try_catch_command (tree_try_catch_command&); - - void do_unwind_protect_cleanup_code (tree_statement_list *list); - - void visit_unwind_protect_command (tree_unwind_protect_command&); - - void visit_while_command (tree_while_command&); - - void visit_do_until_command (tree_do_until_command&); - - void fail (void); - - typedef std::map<std::string, variable_info> var_map; - typedef var_map::iterator var_map_iterator; - typedef std::map<tree*, function_info*> finfo_map; - typedef finfo_map::iterator finfo_map_iterator; - - std::vector<llvm::Value*> value_stack; - var_map variables; - finfo_map compiled_functions; + static void fail (void) + { + throw jit_fail_exception (); + } llvm::LLVMContext &context; llvm::Module *module; + llvm::PassManager *module_pass_manager; llvm::FunctionPassManager *pass_manager; llvm::ExecutionEngine *engine; - llvm::BasicBlock *entry_block; - llvm::Function *print_double; + jit_typeinfo *tinfo; + + compiled_map compiled; }; #endif