Mercurial > octave-nkf
diff src/pt-jit.h @ 14928:39d52aa37a08
Use standard SSA construction algorithm, and support break/continue
author | Max Brister <max@2bass.com> |
---|---|
date | Fri, 01 Jun 2012 19:08:43 -0500 |
parents | aebd296a15c4 |
children | 1f914446157d |
line wrap: on
line diff
--- a/src/pt-jit.h Wed May 30 09:36:38 2012 -0500 +++ b/src/pt-jit.h Fri Jun 01 19:08:43 2012 -0500 @@ -28,6 +28,7 @@ #include <set> #include <stdexcept> #include <vector> +#include <stack> #include "Array.h" #include "Range.h" @@ -48,6 +49,7 @@ // // For loops are compiled again! // if, elseif, and else statements compile again! +// break and continue now work! // Additionally, make check passes using jit. // // The octave low level IR is a linear IR, it works by converting everything to @@ -58,11 +60,12 @@ // // // TODO: -// 1. Support error cases -// 2. Support break/continue -// 3. Fix memory leaks in JIT -// 4. Cleanup/documentation -// 5. ... +// 1. Rename symbol_table::symbol_record_ref -> symbol_table::symbol_reference +// 2. Support some simple matrix case (and cleanup Octave low level IR) +// 3. Support error cases +// 4. Fix memory leaks in JIT +// 5. Cleanup/documentation +// 6. ... // --------------------------------------------------------- @@ -150,12 +153,7 @@ }; // seperate print function to allow easy printing if type is null -static std::ostream& jit_print (std::ostream& os, jit_type *atype) -{ - if (! atype) - return os << "null"; - return os << atype->name (); -} +std::ostream& jit_print (std::ostream& os, jit_type *atype); // Keeps track of overloads for a builtin function. Used for both type inference // and code generation. @@ -496,7 +494,8 @@ JIT_METH(call); \ JIT_METH(extract_argument); \ JIT_METH(store_argument); \ - JIT_METH(phi) + JIT_METH(phi); \ + JIT_METH(variable) #define JIT_VISIT_IR_CLASSES \ JIT_VISIT_IR_NOTEMPLATE; \ @@ -515,6 +514,9 @@ virtual ~jit_value (void); + // replace all uses with + void replace_with (jit_value *value); + jit_type *type (void) const { return ty; } llvm::Type *type_llvm (void) const @@ -540,15 +542,21 @@ return ss.str (); } - virtual std::ostream& print (std::ostream& os, size_t indent = 0) = 0; + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const = 0; - virtual std::ostream& short_print (std::ostream& os) + virtual std::ostream& short_print (std::ostream& os) const { return print (os); } virtual void accept (jit_ir_walker& walker) = 0; + bool has_llvm (void) const + { + return llvm_value; + } + llvm::Value *to_llvm (void) const { + assert (llvm_value); return llvm_value; } @@ -557,10 +565,10 @@ llvm_value = compiled; } protected: - std::ostream& print_indent (std::ostream& os, size_t indent) + std::ostream& print_indent (std::ostream& os, size_t indent) const { - for (size_t i = 0; i < indent; ++i) - os << "\t"; + for (size_t i = 0; i < indent * 8; ++i) + os << " "; return os; } @@ -571,54 +579,7 @@ size_t myuse_count; }; -// defnie accept methods for subclasses -#define JIT_VALUE_ACCEPT(clname) \ - virtual void accept (jit_ir_walker& walker); - -template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T = T, - bool QUOTE=false> -class -jit_const : public jit_value -{ -public: - typedef PASS_T pass_t; - - jit_const (PASS_T avalue) : mvalue (avalue) - { - stash_type (EXTRACT_T ()); - } - - PASS_T value (void) const { return mvalue; } - - virtual std::ostream& print (std::ostream& os, size_t indent) - { - print_indent (os, indent) << type_name () << ": "; - if (QUOTE) - os << "\""; - os << mvalue; - if (QUOTE) - os << "\""; - return os; - } - - JIT_VALUE_ACCEPT (jit_const); -private: - T mvalue; -}; - -typedef jit_const<double, jit_typeinfo::get_scalar> jit_const_scalar; -typedef jit_const<octave_idx_type, jit_typeinfo::get_index> jit_const_index; - -typedef jit_const<std::string, jit_typeinfo::get_string, const std::string&, true> -jit_const_string; -typedef jit_const<jit_range, jit_typeinfo::get_range, const jit_range&> -jit_const_range; - -#define JIT_VISIT_IR_CONST \ - JIT_METH(const_scalar); \ - JIT_METH(const_index); \ - JIT_METH(const_string); \ - JIT_METH(const_range) +std::ostream& operator<< (std::ostream& os, const jit_value& value); class jit_instruction; class jit_block; @@ -705,6 +666,8 @@ size_t mindex; }; +class jit_variable; + class jit_instruction : public jit_value { @@ -764,7 +727,6 @@ jit_type *argument_type (size_t i) const { - assert (argument (i)); return argument (i)->type (); } @@ -808,19 +770,11 @@ virtual bool infer (void) { return false; } - virtual std::ostream& short_print (std::ostream& os) - { - if (mtag.empty ()) - jit_print (os, type ()) << ": #" << id; - else - jit_print (os, type ()) << ": " << mtag << "." << id; + void push_variable (void); - return os; - } + void pop_variable (void); - const std::string& tag (void) const { return mtag; } - - void stash_tag (const std::string& atag) { mtag = atag; } + virtual std::ostream& short_print (std::ostream& os) const; jit_block *parent (void) const { return mparent; } @@ -831,6 +785,10 @@ assert (! mparent); mparent = aparent; } + + jit_variable *tag (void) const; + + void stash_tag (jit_variable *atag); protected: std::vector<jit_type *> already_infered; private: @@ -845,13 +803,65 @@ std::vector<jit_use> arguments; - std::string mtag; + jit_use mtag; + size_t id; jit_block *mparent; }; +// defnie accept methods for subclasses +#define JIT_VALUE_ACCEPT(clname) \ + virtual void accept (jit_ir_walker& walker); + +template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T = T, + bool QUOTE=false> +class +jit_const : public jit_instruction +{ +public: + typedef PASS_T pass_t; + + jit_const (PASS_T avalue) : mvalue (avalue) + { + stash_type (EXTRACT_T ()); + } + + PASS_T value (void) const { return mvalue; } + + virtual std::ostream& print (std::ostream& os, size_t indent) const + { + print_indent (os, indent); + short_print (os) << " = "; + if (QUOTE) + os << "\""; + os << mvalue; + if (QUOTE) + os << "\""; + return os; + } + + JIT_VALUE_ACCEPT (jit_const); +private: + T mvalue; +}; + +typedef jit_const<double, jit_typeinfo::get_scalar> jit_const_scalar; +typedef jit_const<octave_idx_type, jit_typeinfo::get_index> jit_const_index; + +typedef jit_const<std::string, jit_typeinfo::get_string, const std::string&, true> +jit_const_string; +typedef jit_const<jit_range, jit_typeinfo::get_range, const jit_range&> +jit_const_range; + +#define JIT_VISIT_IR_CONST \ + JIT_METH(const_scalar); \ + JIT_METH(const_index); \ + JIT_METH(const_string); \ + JIT_METH(const_range) + class jit_terminator; class jit_phi; +class jit_convert; class jit_block : public jit_value @@ -861,7 +871,11 @@ typedef instruction_list::iterator iterator; typedef instruction_list::const_iterator const_iterator; - jit_block (const std::string& aname) : mname (aname) + typedef std::set<jit_block *> df_set; + typedef df_set::const_iterator df_iterator; + + jit_block (const std::string& aname) : mvisit_count (0), mid (NO_ID), idom (0), + mname (aname) {} const std::string& name (void) const { return mname; } @@ -870,6 +884,15 @@ jit_instruction *append (jit_instruction *instr); + jit_instruction *insert_before (iterator loc, jit_instruction *instr); + + jit_instruction *insert_after (iterator loc, jit_instruction *instr); + + void remove (jit_block::iterator iter) + { + instructions.erase (iter); + } + jit_terminator *terminator (void) const; jit_block *pred (size_t idx) const; @@ -879,9 +902,7 @@ return pred (idx)->terminator (); } - llvm::Value *pred_terminator_llvm (size_t idx) const; - - std::ostream& print_pred (std::ostream& os, size_t idx) + std::ostream& print_pred (std::ostream& os, size_t idx) const { return pred (idx)->short_print (os); } @@ -907,6 +928,8 @@ size_t pred_count (void) const { return use_count (); } + jit_block *succ (size_t i) const; + size_t succ_count (void) const; iterator begin (void) { return instructions.begin (); } @@ -915,15 +938,75 @@ iterator end (void) { return instructions.end (); } - const_iterator end (void) const { return instructions.begin (); } + const_iterator end (void) const { return instructions.end (); } + + iterator phi_begin (void); + + iterator phi_end (void); + + iterator nonphi_begin (void); + + // must label before id is valid + size_t id (void) const { return mid; } + + // dominance frontier + const df_set& df (void) const { return mdf; } + + df_iterator df_begin (void) const { return mdf.begin (); } + + df_iterator df_end (void) const { return mdf.end (); } + + // label with a RPO walk + void label (void) + { + size_t number = 0; + label (mvisit_count, number); + } + + void label (size_t visit_count, size_t& number) + { + if (mvisit_count > visit_count) + return; + ++mvisit_count; + + for (size_t i = 0; i < pred_count (); ++i) + pred (i)->label (visit_count, number); - // search for the phi function with the given tag_name, if no function - // exists then null is returned - jit_phi *search_phi (const std::string& tag_name); + mid = number; + ++number; + } + + // See for idom computation algorithm + // Cooper, Keith D.; Harvey, Timothy J; and Kennedy, Ken (2001). + // "A Simple, Fast Dominance Algorithm" + void compute_idom (jit_block *final) + { + bool changed; + idom = this; + do + changed = final->update_idom (mvisit_count); + while (changed); + } - virtual std::ostream& print (std::ostream& os, size_t indent) + // compute dominance frontier + void compute_df (void) + { + compute_df (mvisit_count); + } + + void create_dom_tree (void) { - print_indent (os, indent) << mname << ":\tpred = "; + create_dom_tree (mvisit_count); + } + + void construct_ssa (jit_convert& convert) + { + do_construct_ssa (convert, mvisit_count); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) const + { + print_indent (os, indent) << mname << ": %pred = "; for (size_t i = 0; i < pred_count (); ++i) { print_pred (os, i); @@ -932,7 +1015,7 @@ } os << std::endl; - for (iterator iter = begin (); iter != end (); ++iter) + for (const_iterator iter = begin (); iter != end (); ++iter) { jit_instruction *instr = *iter; instr->print (os, indent + 1) << std::endl; @@ -940,7 +1023,10 @@ return os; } - virtual std::ostream& short_print (std::ostream& os) + // print dominator infomration + std::ostream& print_dom (std::ostream& os) const; + + virtual std::ostream& short_print (std::ostream& os) const { return os << mname; } @@ -949,18 +1035,93 @@ JIT_VALUE_ACCEPT (block) private: + void compute_df (size_t visit_count); + + bool update_idom (size_t visit_count); + + void finish_phi (jit_block *pred); + + void do_construct_ssa (jit_convert& convert, size_t visit_count); + + void create_dom_tree (size_t visit_count); + + jit_block *idom_intersect (jit_block *b); + + static const size_t NO_ID = static_cast<size_t> (-1); + size_t mvisit_count; + size_t mid; + jit_block *idom; + df_set mdf; + std::vector<jit_block *> dom_succ; std::string mname; instruction_list instructions; mutable std::vector<llvm::BasicBlock *> mpred_llvm; }; + + +// A non-ssa variable +class +jit_variable : public jit_value +{ +public: + jit_variable (const std::string& aname) : mname (aname) {} + + const std::string &name (void) const { return mname; } + + // manipulate the value_stack, for use during SSA construction. The top of the + // value stack represents the current value for this variable + bool has_top (void) const + { + return ! value_stack.empty (); + } + + jit_value *top (void) const + { + return value_stack.top (); + } + + void push (jit_value *v) + { + value_stack.push (v); + } + + void pop (void) + { + value_stack.pop (); + } + + // blocks in which we are used + void use_blocks (jit_block::df_set& result) + { + jit_use *use = first_use (); + while (use) + { + result.insert (use->user_parent ()); + use = use->next (); + } + } + + virtual std::ostream& print (std::ostream& os, size_t indent) const + { + return print_indent (os, indent) << mname; + } + + JIT_VALUE_ACCEPT (variable) +private: + std::string mname; + std::stack<jit_value *> value_stack; +}; + class jit_phi : public jit_instruction { public: - jit_phi (size_t npred, jit_value *adefault = 0) - : jit_instruction (npred, adefault) - {} + jit_phi (jit_variable *avariable, size_t npred) + : jit_instruction (npred) + { + stash_tag (avariable); + } virtual bool infer (void) { @@ -977,13 +1138,13 @@ return false; } - virtual std::ostream& print (std::ostream& os, size_t indent) + virtual std::ostream& print (std::ostream& os, size_t indent) const { std::stringstream ss; print_indent (ss, indent); short_print (ss) << " phi "; std::string ss_str = ss.str (); - std::string indent_str (ss_str.size () + 7, ' '); + std::string indent_str (ss_str.size (), ' '); os << ss_str; jit_block *pblock = parent (); @@ -1028,7 +1189,7 @@ return pllvm == spred_llvm ? succ_llvm : spred_llvm; } - std::ostream& print_sucessor (std::ostream& os, size_t idx = 0) + std::ostream& print_sucessor (std::ostream& os, size_t idx = 0) const { return sucessor (idx)->short_print (os); } @@ -1050,7 +1211,7 @@ size_t sucessor_count (void) const { return 1; } - virtual std::ostream& print (std::ostream& os, size_t indent) + virtual std::ostream& print (std::ostream& os, size_t indent) const { print_indent (os, indent) << "break: "; return print_sucessor (os); @@ -1068,7 +1229,7 @@ jit_value *cond (void) const { return argument (0); } - std::ostream& print_cond (std::ostream& os) + std::ostream& print_cond (std::ostream& os) const { return cond ()->short_print (os); } @@ -1086,7 +1247,7 @@ size_t sucessor_count (void) const { return 2; } - virtual std::ostream& print (std::ostream& os, size_t indent) + virtual std::ostream& print (std::ostream& os, size_t indent) const { print_indent (os, indent) << "cond_break: "; print_cond (os) << ", "; @@ -1122,11 +1283,11 @@ return mfunction.get_overload (argument_types ()); } - virtual std::ostream& print (std::ostream& os, size_t indent) + virtual std::ostream& print (std::ostream& os, size_t indent) const { print_indent (os, indent); - if (use_count ()) + if (use_count () || tag ()) short_print (os) << " = "; os << "call " << mfunction.name () << " ("; @@ -1150,11 +1311,16 @@ jit_extract_argument : public jit_instruction { public: - jit_extract_argument (jit_type *atype, const std::string& aname) + jit_extract_argument (jit_type *atype, jit_variable *var) : jit_instruction () { stash_type (atype); - stash_tag (aname); + stash_tag (var); + } + + const std::string& name (void) const + { + return tag ()->name (); } const jit_function::overload& overload (void) const @@ -1162,10 +1328,12 @@ return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); } - virtual std::ostream& print (std::ostream& os, size_t indent) + virtual std::ostream& print (std::ostream& os, size_t indent) const { print_indent (os, indent); - return short_print (os) << " = extract: " << tag (); + os << "exract "; + short_print (os); + return os; } JIT_VALUE_ACCEPT (extract_argument) @@ -1175,10 +1343,15 @@ jit_store_argument : public jit_instruction { public: - jit_store_argument (const std::string& aname, jit_value *aresult) - : jit_instruction (aresult) + jit_store_argument (jit_variable *var) + : jit_instruction (var) { - stash_tag (aname); + stash_tag (var); + } + + const std::string& name (void) const + { + return tag ()->name (); } const jit_function::overload& overload (void) const @@ -1201,10 +1374,11 @@ return result ()->to_llvm (); } - virtual std::ostream& print (std::ostream& os, size_t indent) + virtual std::ostream& print (std::ostream& os, size_t indent) const { jit_value *res = result (); - print_indent (os, indent) << tag () << " <- "; + print_indent (os, indent) << "store "; + short_print (os) << " = "; return res->short_print (os); } @@ -1338,158 +1512,6 @@ void visit_while_command (tree_while_command&); void visit_do_until_command (tree_do_until_command&); -private: - std::vector<std::pair<std::string, bool> > arguments; - type_bound_vector bounds; - - class - variable_map - { - // internal variable map - typedef std::map<std::string, jit_value *> ivar_map; - public: - typedef ivar_map::iterator iterator; - typedef ivar_map::const_iterator const_iterator; - - variable_map (variable_map *aparent, jit_block *ablock) : mparent (aparent), - mblock (ablock) - {} - - virtual ~variable_map () {} - - variable_map *parent (void) const { return mparent; } - - jit_block *block (void) const { return mblock; } - - jit_value *get (const std::string& name) - { - ivar_map::iterator iter = vars.find (name); - if (iter != vars.end ()) - return iter->second; - - if (mparent) - { - jit_value *pval = mparent->get (name); - return insert (name, pval); - } - - return insert (name, 0); - } - - jit_value *set (const std::string& name, jit_value *val) - { - get (name); // force insertion - return vars[name] = val; - } - - iterator begin (void) { return vars.begin (); } - const_iterator begin (void) const { return vars.begin (); } - - iterator end (void) { return vars.end (); } - const_iterator end (void) const { return vars.end (); } - - size_t size (void) const { return vars.size (); } - protected: - virtual jit_value *insert (const std::string& name, jit_value *pval) = 0; - - ivar_map vars; - private: - variable_map *mparent; - jit_block *mblock; - }; - - class - toplevel_map : public variable_map - { - public: - toplevel_map (jit_convert& aconvert, jit_block *aentry) - : variable_map (0, aentry), convert (aconvert) {} - protected: - virtual jit_value *insert (const std::string& name, jit_value *pval); - private: - jit_convert& convert; - }; - - class - for_map : public variable_map - { - public: - typedef variable_map::iterator iterator; - typedef variable_map::const_iterator const_iterator; - - for_map (variable_map *aparent, jit_block *ablock) - : variable_map (aparent, ablock) - { - // force insertion of all phi nodes - for (iterator iter = aparent->begin (); iter != aparent->end (); ++iter) - get (iter->first); - } - - void finish_phi (variable_map& from) - { - jit_block *for_body = block (); - for (jit_block::iterator iter = for_body->begin (); - iter != for_body->end () && dynamic_cast<jit_phi *> (*iter); ++iter) - { - jit_instruction *node = *iter; - if (! node->argument (1)) - node->stash_argument (1, from.get (node->tag ())); - } - } - protected: - virtual jit_value *insert (const std::string& name, jit_value *pval) - { - jit_phi *ret = new jit_phi (2); - ret->stash_tag (name); - block ()->prepend (ret); - ret->stash_argument (0, pval); - return vars[name] = ret; - } - }; - - class - compound_map : public variable_map - { - public: - compound_map (variable_map *aparent) : variable_map (aparent, 0) - {} - protected: - virtual jit_value *insert (const std::string&, jit_value *pval) - { - return pval; - } - }; - - - variable_map *variables; - - // used instead of return values from visit_* functions - jit_value *result; - - jit_block *block; - jit_block *final_block; - - llvm::Function *function; - - std::list<jit_block *> blocks; - - std::list<jit_instruction *> worklist; - - std::list<jit_value *> constants; - - std::list<jit_value *> all_values; - - void do_assign (const std::string& lhs, jit_value *rhs, bool print); - - jit_value *visit (tree *tee) { return visit (*tee); } - - jit_value *visit (tree& tee); - - void append_users (jit_value *v) - { - for (jit_use *use = v->first_use (); use; use = use->next ()) - worklist.push_back (use->user ()); - } // this would be easier with variadic templates template <typename T> @@ -1523,19 +1545,87 @@ track_value (ret); return ret; } +private: + std::vector<std::pair<std::string, bool> > arguments; + type_bound_vector bounds; + + // used instead of return values from visit_* functions + jit_instruction *result; + + jit_block *entry_block; + + jit_block *block; + + llvm::Function *function; + + std::list<jit_block *> blocks; + + std::list<jit_instruction *> worklist; + + std::list<jit_value *> constants; + + std::list<jit_value *> all_values; + + size_t iterator_count; + + typedef std::map<std::string, jit_variable *> vmap_t; + vmap_t vmap; + + jit_variable *get_variable (const std::string& vname); + + jit_instruction *do_assign (const std::string& lhs, jit_instruction *rhs, + bool print); + + jit_instruction *visit (tree *tee) { return visit (*tee); } + + jit_instruction *visit (tree& tee); + + void append_users (jit_value *v) + { + for (jit_use *use = v->first_use (); use; use = use->next ()) + worklist.push_back (use->user ()); + } void track_value (jit_value *value) { - if (value->type () && ! dynamic_cast<jit_instruction *>(value)) + if (value->type ()) constants.push_back (value); all_values.push_back (value); } - // place phi nodes in the current block to merge ref with variables - // we assume the same number of deffinitions - void merge (jit_block *merge_block, variable_map& merge_vars, - jit_block *incomming_block, - const variable_map& incomming_vars); + void construct_ssa (jit_block *final_block); + + void print_blocks (const std::string& header) + { + std::cout << "-------------------- " << header << " --------------------\n"; + for (std::list<jit_block *>::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + { + assert (*iter); + (*iter)->print (std::cout, 0); + } + std::cout << std::endl; + } + + void print_dom (void) + { + std::cout << "-------------------- dom info --------------------\n"; + for (std::list<jit_block *>::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + { + assert (*iter); + (*iter)->print_dom (std::cout); + } + std::cout << std::endl; + } + + typedef std::list<jit_block *> break_list; + + bool breaking; // true if we are breaking OR continuing + break_list breaks; + break_list continues; + + void finish_breaks (jit_block *dest, const break_list& lst); // this case is much simpler, just convert from the jit ir to llvm class @@ -1544,8 +1634,7 @@ public: llvm::Function *convert (llvm::Module *module, const std::vector<std::pair<std::string, bool> >& args, - const std::list<jit_block *>& blocks, - const std::list<jit_value *>& constants); + const std::list<jit_block *>& blocks); #define JIT_METH(clname) \ virtual void visit (jit_ ## clname&);