Mercurial > octave-nkf
diff src/pt-jit.cc @ 14935:5801e031a3b5
Place releases after last use and generalize dom visiting
author | Max Brister <max@2bass.com> |
---|---|
date | Mon, 04 Jun 2012 13:10:44 -0500 |
parents | 1f914446157d |
children | 32deb562ae77 |
line wrap: on
line diff
--- a/src/pt-jit.cc Sun Jun 03 16:30:21 2012 -0500 +++ b/src/pt-jit.cc Mon Jun 04 13:10:44 2012 -0500 @@ -339,30 +339,30 @@ fn = create_function ("octave_jit_grab_any", any, any); engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_grab_any)); - grab_fn.add_overload (fn, false, any, any); + grab_fn.add_overload (fn, false, false, any, any); grab_fn.stash_name ("grab"); // grab scalar fn = create_identity (scalar); - grab_fn.add_overload (fn, false, scalar, scalar); + grab_fn.add_overload (fn, false, false, scalar, scalar); // grab index fn = create_identity (index); - grab_fn.add_overload (fn, false, index, index); + grab_fn.add_overload (fn, false, false, index, index); // release any fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any)); - release_fn.add_overload (fn, false, 0, any); + release_fn.add_overload (fn, false, false, 0, any); release_fn.stash_name ("release"); // release scalar fn = create_identity (scalar); - release_fn.add_overload (fn, false, 0, scalar); + release_fn.add_overload (fn, false, false, 0, scalar); // release index fn = create_identity (index); - release_fn.add_overload (fn, false, 0, index); + release_fn.add_overload (fn, false, false, 0, index); // now for binary scalar operations // FIXME: Finish all operations @@ -401,7 +401,7 @@ builder.CreateRet (zero); } llvm::verifyFunction (*fn); - for_init_fn.add_overload (fn, false, index, range); + for_init_fn.add_overload (fn, false, false, index, range); // bounds check for for loop for_check_fn.stash_name ("for_check"); @@ -417,7 +417,7 @@ builder.CreateRet (ret); } llvm::verifyFunction (*fn); - for_check_fn.add_overload (fn, false, boolean, range, index); + for_check_fn.add_overload (fn, false, false, boolean, range, index); // index variabe for for loop for_index_fn.stash_name ("for_index"); @@ -437,7 +437,7 @@ builder.CreateRet (ret); } llvm::verifyFunction (*fn); - for_index_fn.add_overload (fn, false, scalar, range, index); + for_index_fn.add_overload (fn, false, false, scalar, range, index); // logically true // FIXME: Check for NaN @@ -450,14 +450,14 @@ builder.CreateRet (ret); } llvm::verifyFunction (*fn); - logically_true.add_overload (fn, true, boolean, scalar); + logically_true.add_overload (fn, true, false, boolean, scalar); fn = create_function ("octave_logically_true_bool", boolean, boolean); body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); builder.CreateRet (fn->arg_begin ()); llvm::verifyFunction (*fn); - logically_true.add_overload (fn, false, boolean, boolean); + logically_true.add_overload (fn, false, false, boolean, boolean); logically_true.stash_name ("logically_true"); casts[any->type_id ()].stash_name ("(any)"); @@ -466,20 +466,20 @@ // cast any <- scalar fn = create_function ("octave_jit_cast_any_scalar", any, scalar); engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_scalar)); - casts[any->type_id ()].add_overload (fn, false, any, scalar); + casts[any->type_id ()].add_overload (fn, false, false, any, scalar); // cast scalar <- any fn = create_function ("octave_jit_cast_scalar_any", scalar, any); engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any)); - casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + casts[scalar->type_id ()].add_overload (fn, false, false, scalar, any); // cast any <- any fn = create_identity (any); - casts[any->type_id ()].add_overload (fn, false, any, any); + casts[any->type_id ()].add_overload (fn, false, false, any, any); // cast scalar <- scalar fn = create_identity (scalar); - casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); + casts[scalar->type_id ()].add_overload (fn, false, false, scalar, scalar); } void @@ -494,7 +494,7 @@ ty->to_llvm ()); engine->addGlobalMapping (fn, call); - jit_function::overload ol (fn, false, 0, string, ty); + jit_function::overload ol (fn, false, true, 0, string, ty); print_fn.add_overload (ol); } @@ -517,7 +517,7 @@ builder.CreateRet (ret); llvm::verifyFunction (*fn); - jit_function::overload ol(fn, false, ty, ty, ty); + jit_function::overload ol(fn, false, false, ty, ty, ty); binary_ops[op].add_overload (ol); } @@ -539,7 +539,7 @@ builder.CreateRet (ret); llvm::verifyFunction (*fn); - jit_function::overload ol (fn, false, boolean, ty, ty); + jit_function::overload ol (fn, false, false, boolean, ty, ty); binary_ops[op].add_overload (ol); } @@ -561,7 +561,7 @@ builder.CreateRet (ret); llvm::verifyFunction (*fn); - jit_function::overload ol (fn, false, boolean, ty, ty); + jit_function::overload ol (fn, false, false, boolean, ty, ty); binary_ops[op].add_overload (ol); } @@ -666,6 +666,13 @@ // -------------------- jit_instruction -------------------- void +jit_instruction::remove (void) +{ + if (mparent) + mparent->remove (mlocation); +} + +void jit_instruction::push_variable (void) { if (tag ()) @@ -715,23 +722,40 @@ jit_block::prepend (jit_instruction *instr) { instructions.push_front (instr); - instr->stash_parent (this); + instr->stash_parent (this, instructions.begin ()); return instr; } jit_instruction * +jit_block::prepend_after_phi (jit_instruction *instr) +{ + // FIXME: Make this O(1) + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *temp = *iter; + if (! temp->is_phi ()) + { + insert_before (iter, instr); + return instr; + } + } + + return append (instr); +} + +jit_instruction * jit_block::append (jit_instruction *instr) { instructions.push_back (instr); - instr->stash_parent (this); + instr->stash_parent (this, --instructions.end ()); return instr; } jit_instruction * jit_block::insert_before (iterator loc, jit_instruction *instr) { - instructions.insert (loc, instr); - instr->stash_parent (this); + iterator iloc = instructions.insert (loc, instr); + instr->stash_parent (this, iloc); return instr; } @@ -739,8 +763,8 @@ jit_block::insert_after (iterator loc, jit_instruction *instr) { ++loc; - instructions.insert (loc, instr); - instr->stash_parent (this); + iterator iloc = instructions.insert (loc, instr); + instr->stash_parent (this, iloc); return instr; } @@ -751,7 +775,7 @@ return 0; jit_instruction *last = instructions.back (); - return dynamic_cast<jit_terminator *> (last); + return last->to_terminator (); } jit_block * @@ -925,61 +949,8 @@ } void -jit_block::finish_phi (jit_block *apred) -{ - size_t pred_idx = pred_index (apred); - for (iterator iter = begin (); iter != end () - && dynamic_cast<jit_phi *> (*iter); ++iter) - { - jit_instruction *phi = *iter; - jit_variable *var = phi->tag (); - phi->stash_argument (pred_idx, var->top ()); - } -} - -void -jit_block::do_construct_ssa (jit_convert& convert, size_t visit_count) +jit_block::pop_all (void) { - if (mvisit_count > visit_count) - return; - ++mvisit_count; - - for (iterator iter = begin (); iter != end (); ++iter) - { - jit_instruction *instr = *iter; - bool isphi = dynamic_cast<jit_phi *> (instr); - - if (! isphi) - { - for (size_t i = 0; i < instr->argument_count (); ++i) - { - jit_variable *var; - var = dynamic_cast<jit_variable *> (instr->argument (i)); - if (var) - instr->stash_argument (i, var->top ()); - } - - // FIXME: Remove need for jit_store_argument dynamic cast - jit_variable *tag = instr->tag (); - if (tag && tag->has_top () - && ! dynamic_cast<jit_store_argument *> (instr)) - { - jit_call *rel = convert.create<jit_call> (jit_typeinfo::release, - tag->top ()); - insert_after (iter, rel); - ++iter; - } - } - - instr->push_variable (); - } - - for (size_t i = 0; i < succ_count (); ++i) - succ (i)->finish_phi (this); - - for (size_t i = 0; i < dom_succ.size (); ++i) - dom_succ[i]->do_construct_ssa (convert, visit_count); - for (iterator iter = begin (); iter != end (); ++iter) { jit_instruction *instr = *iter; @@ -1021,6 +992,18 @@ // -------------------- jit_call -------------------- bool +jit_call::dead (void) const +{ + return ! has_side_effects () && use_count () == 0; +} + +bool +jit_call::almost_dead (void) const +{ + return ! has_side_effects () && use_count () <= 1; +} + +bool jit_call::infer (void) { // FIXME: explain algorithm @@ -1082,10 +1065,6 @@ iter != constants.end (); ++iter) append_users (*iter); -#ifdef OCTAVE_JIT_DEBUG - print_blocks ("octave jit ir"); -#endif - // FIXME: Describe algorithm here while (worklist.size ()) { @@ -1096,6 +1075,8 @@ append_users (next); } + place_releases (); + #ifdef OCTAVE_JIT_DEBUG std::cout << "-------------------- Compiling tree --------------------\n"; std::cout << tee.str_print_code () << std::endl; @@ -1107,7 +1088,7 @@ for (jit_block::iterator iter = entry_block->begin (); iter != entry_block->end (); ++iter) { - if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *> (*iter)) + if (jit_extract_argument *extract = (*iter)->to_extract_argument ()) arguments.push_back (std::make_pair (extract->name (), true)); } @@ -1714,7 +1695,7 @@ entry_block->compute_df (); entry_block->create_dom_tree (); - // insert phi nodes where needed + // insert phi nodes where needed, this is done on a per variable basis for (vmap_t::iterator iter = vmap.begin (); iter != vmap.end (); ++iter) { jit_block::df_set visited, added_phi; @@ -1748,13 +1729,59 @@ } } - entry_block->construct_ssa (*this); + entry_block->visit_dom (&jit_convert::do_construct_ssa, &jit_block::pop_all); } void -jit_convert::finish_breaks (jit_block *dest, const break_list& lst) +jit_convert::do_construct_ssa (jit_block& block) { - for (break_list::const_iterator iter = lst.begin (); iter != lst.end (); + // replace variables with their current SSA value + for (jit_block::iterator iter = block.begin (); iter != block.end (); ++iter) + { + jit_instruction *instr = *iter; + if (! instr->is_phi ()) + { + for (size_t i = 0; i < instr->argument_count (); ++i) + { + jit_variable *var; + var = instr->argument_variable (i); + assert (var == instr->argument (i)->to_variable ()); + assert (var == dynamic_cast<jit_variable *> (instr->argument (i))); + if (var) + instr->stash_argument (i, var->top ()); + } + } + + instr->push_variable (); + } + + // finish phi nodes of sucessors + for (size_t i = 0; i < block.succ_count (); ++i) + { + jit_block *finish = block.succ (i); + size_t pred_idx = finish->pred_index (&block); + + for (jit_block::iterator iter = finish->begin (); iter != finish->end () + && (*iter)->is_phi (); ++iter) + { + jit_instruction *phi = *iter; + jit_variable *var = phi->tag (); + phi->stash_argument (pred_idx, var->top ()); + } + } +} + +void +jit_convert::place_releases (void) +{ + jit_convert::release_placer placer (*this); + entry_block->visit_dom (placer, &jit_block::pop_all); +} + +void +jit_convert::finish_breaks (jit_block *dest, const block_list& lst) +{ + for (block_list::const_iterator iter = lst.begin (); iter != lst.end (); ++iter) { jit_block *b = *iter; @@ -1762,6 +1789,42 @@ } } +// -------------------- jit_convert::release_placer -------------------- +void +jit_convert::release_placer::operator() (jit_block& block) +{ + for (jit_block::iterator iter = block.begin (); iter != block.end (); ++iter) + { + jit_instruction *instr = *iter; + for (size_t i = 0; i < instr->argument_count (); ++i) + { + jit_instruction *arg = instr->argument_instruction (i); + if (arg && arg->tag ()) + { + jit_variable *tag = arg->tag (); + tag->stash_last_use (instr); + } + } + + jit_variable *tag = instr->tag (); + if (tag && ! (instr->is_phi () || instr->is_store_argument ()) + && tag->has_top ()) + { + jit_instruction *last_use = tag->last_use (); + jit_call *release = convert.create<jit_call> (jit_typeinfo::release, + tag->top ()); + release->infer (); + if (last_use && last_use->parent () == &block + && ! last_use->is_phi ()) + block.insert_after (last_use->location (), release); + else + block.prepend_after_phi (release); + } + + instr->push_variable (); + } +} + // -------------------- jit_convert::convert_llvm -------------------- llvm::Function * jit_convert::convert_llvm::convert (llvm::Module *module, @@ -1812,42 +1875,10 @@ { jit_block& block = **biter; for (jit_block::iterator piter = block.begin (); - piter != block.end () && dynamic_cast<jit_phi *> (*piter); ++piter) + piter != block.end () && (*piter)->is_phi (); ++piter) { - // our phi nodes don't have to have the same incomming type, - // so we do casts here jit_instruction *phi = *piter; - jit_block *pblock = phi->parent (); - llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ()); - for (size_t i = 0; i < phi->argument_count (); ++i) - { - llvm::BasicBlock *pred = pblock->pred_llvm (i); - if (phi->argument_type_llvm (i) == phi->type_llvm ()) - { - llvm_phi->addIncoming (phi->argument_llvm (i), pred); - } - else - { - // add cast right before pred terminator - builder.SetInsertPoint (--pred->end ()); - - const jit_function::overload& ol - = jit_typeinfo::cast (phi->type (), - phi->argument_type (i)); - if (! ol.function) - { - std::stringstream ss; - ss << "No cast for phi(" << i << "): "; - phi->print (ss); - fail (ss.str ()); - } - - llvm::Value *casted; - casted = builder.CreateCall (ol.function, - phi->argument_llvm (i)); - llvm_phi->addIncoming (casted, pred); - } - } + finish_phi (phi); } } @@ -1864,6 +1895,87 @@ } void +jit_convert::convert_llvm::finish_phi (jit_instruction *phi) +{ + jit_block *pblock = phi->parent (); + llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ()); + + bool can_remove = llvm_phi->use_empty (); + if (! can_remove && llvm_phi->hasOneUse () && phi->use_count () == 1) + { + jit_instruction *user = phi->first_use ()->user (); + can_remove = user->is_call (); // must be a remove + } + + if (can_remove) + { + // replace with releases along each incomming branch + while (! llvm_phi->use_empty ()) + { + llvm::Instruction *llvm_instr; + llvm_instr = llvm::cast<llvm::Instruction> (llvm_phi->use_back ()); + llvm_instr->eraseFromParent (); + } + + llvm_phi->eraseFromParent (); + phi->stash_llvm (0); + + for (size_t i = 0; i < phi->argument_count (); ++i) + { + jit_value *arg = phi->argument (i); + if (arg->has_llvm () && phi->argument_type (i) != phi->type ()) + { + llvm::BasicBlock *pred = pblock->pred_llvm (i); + builder.SetInsertPoint (--pred->end ()); + const jit_function::overload& ol + = jit_typeinfo::get_release (phi->argument_type (i)); + if (! ol.function) + { + std::stringstream ss; + ss << "No release for phi(" << i << "): "; + phi->print (ss); + fail (ss.str ()); + } + + builder.CreateCall (ol.function, phi->argument_llvm (i)); + } + } + } + else + { + for (size_t i = 0; i < phi->argument_count (); ++i) + { + llvm::BasicBlock *pred = pblock->pred_llvm (i); + if (phi->argument_type (i) == phi->type ()) + { + llvm_phi->addIncoming (phi->argument_llvm (i), pred); + } + else + { + // add cast right before pred terminator + builder.SetInsertPoint (--pred->end ()); + + const jit_function::overload& ol + = jit_typeinfo::cast (phi->type (), + phi->argument_type (i)); + if (! ol.function) + { + std::stringstream ss; + ss << "No cast for phi(" << i << "): "; + phi->print (ss); + fail (ss.str ()); + } + + llvm::Value *casted; + casted = builder.CreateCall (ol.function, + phi->argument_llvm (i)); + llvm_phi->addIncoming (casted, pred); + } + } + } +} + +void jit_convert::convert_llvm::visit (jit_const_string& cs) { cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ()));