diff src/pt-jit.cc @ 14935:5801e031a3b5

Place releases after last use and generalize dom visiting
author Max Brister <max@2bass.com>
date Mon, 04 Jun 2012 13:10:44 -0500
parents 1f914446157d
children 32deb562ae77
line wrap: on
line diff
--- a/src/pt-jit.cc	Sun Jun 03 16:30:21 2012 -0500
+++ b/src/pt-jit.cc	Mon Jun 04 13:10:44 2012 -0500
@@ -339,30 +339,30 @@
   fn = create_function ("octave_jit_grab_any", any, any);
                         
   engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_grab_any));
-  grab_fn.add_overload (fn, false, any, any);
+  grab_fn.add_overload (fn, false, false, any, any);
   grab_fn.stash_name ("grab");
 
   // grab scalar
   fn = create_identity (scalar);
-  grab_fn.add_overload (fn, false, scalar, scalar);
+  grab_fn.add_overload (fn, false, false, scalar, scalar);
 
   // grab index
   fn = create_identity (index);
-  grab_fn.add_overload (fn, false, index, index);
+  grab_fn.add_overload (fn, false, false, index, index);
 
   // release any
   fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ());
   engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any));
-  release_fn.add_overload (fn, false, 0, any);
+  release_fn.add_overload (fn, false, false, 0, any);
   release_fn.stash_name ("release");
 
   // release scalar
   fn = create_identity (scalar);
-  release_fn.add_overload (fn, false, 0, scalar);
+  release_fn.add_overload (fn, false, false, 0, scalar);
 
   // release index
   fn = create_identity (index);
-  release_fn.add_overload (fn, false, 0, index);
+  release_fn.add_overload (fn, false, false, 0, index);
 
   // now for binary scalar operations
   // FIXME: Finish all operations
@@ -401,7 +401,7 @@
     builder.CreateRet (zero);
   }
   llvm::verifyFunction (*fn);
-  for_init_fn.add_overload (fn, false, index, range);
+  for_init_fn.add_overload (fn, false, false, index, range);
 
   // bounds check for for loop
   for_check_fn.stash_name ("for_check");
@@ -417,7 +417,7 @@
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  for_check_fn.add_overload (fn, false, boolean, range, index);
+  for_check_fn.add_overload (fn, false, false, boolean, range, index);
 
   // index variabe for for loop
   for_index_fn.stash_name ("for_index");
@@ -437,7 +437,7 @@
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  for_index_fn.add_overload (fn, false, scalar, range, index);
+  for_index_fn.add_overload (fn, false, false, scalar, range, index);
 
   // logically true
   // FIXME: Check for NaN
@@ -450,14 +450,14 @@
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  logically_true.add_overload (fn, true, boolean, scalar);
+  logically_true.add_overload (fn, true, false, boolean, scalar);
 
   fn = create_function ("octave_logically_true_bool", boolean, boolean);
   body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   builder.CreateRet (fn->arg_begin ());
   llvm::verifyFunction (*fn);
-  logically_true.add_overload (fn, false, boolean, boolean);
+  logically_true.add_overload (fn, false, false, boolean, boolean);
   logically_true.stash_name ("logically_true");
 
   casts[any->type_id ()].stash_name ("(any)");
@@ -466,20 +466,20 @@
   // cast any <- scalar
   fn = create_function ("octave_jit_cast_any_scalar", any, scalar);
   engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_scalar));
-  casts[any->type_id ()].add_overload (fn, false, any, scalar);
+  casts[any->type_id ()].add_overload (fn, false, false, any, scalar);
 
   // cast scalar <- any
   fn = create_function ("octave_jit_cast_scalar_any", scalar, any);
   engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any));
-  casts[scalar->type_id ()].add_overload (fn, false, scalar, any);
+  casts[scalar->type_id ()].add_overload (fn, false, false, scalar, any);
 
   // cast any <- any
   fn = create_identity (any);
-  casts[any->type_id ()].add_overload (fn, false, any, any);
+  casts[any->type_id ()].add_overload (fn, false, false, any, any);
 
   // cast scalar <- scalar
   fn = create_identity (scalar);
-  casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar);
+  casts[scalar->type_id ()].add_overload (fn, false, false, scalar, scalar);
 }
 
 void
@@ -494,7 +494,7 @@
                                         ty->to_llvm ());
   engine->addGlobalMapping (fn, call);
 
-  jit_function::overload ol (fn, false, 0, string, ty);
+  jit_function::overload ol (fn, false, true, 0, string, ty);
   print_fn.add_overload (ol);
 }
 
@@ -517,7 +517,7 @@
   builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
-  jit_function::overload ol(fn, false, ty, ty, ty);
+  jit_function::overload ol(fn, false, false, ty, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
@@ -539,7 +539,7 @@
   builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
-  jit_function::overload ol (fn, false, boolean, ty, ty);
+  jit_function::overload ol (fn, false, false, boolean, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
@@ -561,7 +561,7 @@
   builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
-  jit_function::overload ol (fn, false, boolean, ty, ty);
+  jit_function::overload ol (fn, false, false, boolean, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
@@ -666,6 +666,13 @@
 
 // -------------------- jit_instruction --------------------
 void
+jit_instruction::remove (void)
+{
+  if (mparent)
+    mparent->remove (mlocation);
+}
+
+void
 jit_instruction::push_variable (void)
 {
   if (tag ())
@@ -715,23 +722,40 @@
 jit_block::prepend (jit_instruction *instr)
 {
   instructions.push_front (instr);
-  instr->stash_parent (this);
+  instr->stash_parent (this, instructions.begin ());
   return instr;
 }
 
 jit_instruction *
+jit_block::prepend_after_phi (jit_instruction *instr)
+{
+  // FIXME: Make this O(1)
+  for (iterator iter = begin (); iter != end (); ++iter)
+    {
+      jit_instruction *temp = *iter;
+      if (! temp->is_phi ())
+        {
+          insert_before (iter, instr);
+          return instr;
+        }
+    }
+
+  return append (instr);
+}
+
+jit_instruction *
 jit_block::append (jit_instruction *instr)
 {
   instructions.push_back (instr);
-  instr->stash_parent (this);
+  instr->stash_parent (this, --instructions.end ());
   return instr;
 }
 
 jit_instruction *
 jit_block::insert_before (iterator loc, jit_instruction *instr)
 {
-  instructions.insert (loc, instr);
-  instr->stash_parent (this);
+  iterator iloc = instructions.insert (loc, instr);
+  instr->stash_parent (this, iloc);
   return instr;
 }
 
@@ -739,8 +763,8 @@
 jit_block::insert_after (iterator loc, jit_instruction *instr)
 {
   ++loc;
-  instructions.insert (loc, instr);
-  instr->stash_parent (this);
+  iterator iloc = instructions.insert (loc, instr);
+  instr->stash_parent (this, iloc);
   return instr;
 }
 
@@ -751,7 +775,7 @@
     return 0;
 
   jit_instruction *last = instructions.back ();
-  return dynamic_cast<jit_terminator *> (last);
+  return last->to_terminator ();
 }
 
 jit_block *
@@ -925,61 +949,8 @@
 }
 
 void
-jit_block::finish_phi (jit_block *apred)
-{
-  size_t pred_idx = pred_index (apred);
-  for (iterator iter = begin (); iter != end ()
-         && dynamic_cast<jit_phi *> (*iter); ++iter)
-    {
-      jit_instruction *phi = *iter;
-      jit_variable *var = phi->tag ();
-      phi->stash_argument (pred_idx, var->top ());
-    }
-}
-
-void
-jit_block::do_construct_ssa (jit_convert& convert, size_t visit_count)
+jit_block::pop_all (void)
 {
-  if (mvisit_count > visit_count)
-    return;
-  ++mvisit_count;
-
-  for (iterator iter = begin (); iter != end (); ++iter)
-    {
-      jit_instruction *instr = *iter;
-      bool isphi = dynamic_cast<jit_phi *> (instr);
-
-      if (! isphi)
-        {
-          for (size_t i = 0; i < instr->argument_count (); ++i)
-            {
-              jit_variable *var;
-              var = dynamic_cast<jit_variable *> (instr->argument (i));
-              if (var)
-                instr->stash_argument (i, var->top ());
-            }
-
-          // FIXME: Remove need for jit_store_argument dynamic cast
-          jit_variable *tag = instr->tag ();
-          if (tag && tag->has_top ()
-              && ! dynamic_cast<jit_store_argument *> (instr))
-            {
-              jit_call *rel = convert.create<jit_call> (jit_typeinfo::release,
-                                                        tag->top ());
-              insert_after (iter, rel);
-              ++iter;
-            }
-        }
-
-      instr->push_variable ();
-    }
-
-  for (size_t i = 0; i < succ_count (); ++i)
-    succ (i)->finish_phi (this);
-
-  for (size_t i = 0; i < dom_succ.size (); ++i)
-    dom_succ[i]->do_construct_ssa (convert, visit_count);
-
   for (iterator iter = begin (); iter != end (); ++iter)
     {
       jit_instruction *instr = *iter;
@@ -1021,6 +992,18 @@
 
 // -------------------- jit_call --------------------
 bool
+jit_call::dead (void) const
+{
+  return ! has_side_effects () && use_count () == 0;
+}
+
+bool
+jit_call::almost_dead (void) const
+{
+  return ! has_side_effects () && use_count () <= 1;
+}
+
+bool
 jit_call::infer (void)
 {
   // FIXME: explain algorithm
@@ -1082,10 +1065,6 @@
        iter != constants.end (); ++iter)
     append_users (*iter);
 
-#ifdef OCTAVE_JIT_DEBUG
-  print_blocks ("octave jit ir");
-#endif
-
   // FIXME: Describe algorithm here
   while (worklist.size ())
     {
@@ -1096,6 +1075,8 @@
         append_users (next);
     }
 
+  place_releases ();
+
 #ifdef OCTAVE_JIT_DEBUG
   std::cout << "-------------------- Compiling tree --------------------\n";
   std::cout << tee.str_print_code () << std::endl;
@@ -1107,7 +1088,7 @@
   for (jit_block::iterator iter = entry_block->begin ();
        iter != entry_block->end (); ++iter)
     {
-      if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *> (*iter))
+      if (jit_extract_argument *extract = (*iter)->to_extract_argument ())
         arguments.push_back (std::make_pair (extract->name (), true));
     }
 
@@ -1714,7 +1695,7 @@
   entry_block->compute_df ();
   entry_block->create_dom_tree ();
 
-  // insert phi nodes where needed
+  // insert phi nodes where needed, this is done on a per variable basis
   for (vmap_t::iterator iter = vmap.begin (); iter != vmap.end (); ++iter)
     {
       jit_block::df_set visited, added_phi;
@@ -1748,13 +1729,59 @@
         }
     }
 
-  entry_block->construct_ssa (*this);
+  entry_block->visit_dom (&jit_convert::do_construct_ssa, &jit_block::pop_all);
 }
 
 void
-jit_convert::finish_breaks (jit_block *dest, const break_list& lst)
+jit_convert::do_construct_ssa (jit_block& block)
 {
-  for (break_list::const_iterator iter = lst.begin (); iter != lst.end ();
+  // replace variables with their current SSA value
+  for (jit_block::iterator iter = block.begin (); iter != block.end (); ++iter)
+    {
+      jit_instruction *instr = *iter;
+      if (! instr->is_phi ())
+        {
+          for (size_t i = 0; i < instr->argument_count (); ++i)
+            {
+              jit_variable *var;
+              var = instr->argument_variable (i);
+              assert (var == instr->argument (i)->to_variable ());
+              assert (var == dynamic_cast<jit_variable *> (instr->argument (i)));
+              if (var)
+                instr->stash_argument (i, var->top ());
+            }
+        }
+
+      instr->push_variable ();
+    }
+
+  // finish phi nodes of sucessors
+  for (size_t i = 0; i < block.succ_count (); ++i)
+    {
+      jit_block *finish = block.succ (i);
+      size_t pred_idx = finish->pred_index (&block);
+
+      for (jit_block::iterator iter = finish->begin (); iter != finish->end ()
+             && (*iter)->is_phi (); ++iter)
+        {
+          jit_instruction *phi = *iter;
+          jit_variable *var = phi->tag ();
+          phi->stash_argument (pred_idx, var->top ());
+        }
+    }
+}
+
+void
+jit_convert::place_releases (void)
+{
+  jit_convert::release_placer placer (*this);
+  entry_block->visit_dom (placer, &jit_block::pop_all);
+}
+
+void
+jit_convert::finish_breaks (jit_block *dest, const block_list& lst)
+{
+  for (block_list::const_iterator iter = lst.begin (); iter != lst.end ();
        ++iter)
     {
       jit_block *b = *iter;
@@ -1762,6 +1789,42 @@
     }
 }
 
+// -------------------- jit_convert::release_placer --------------------
+void
+jit_convert::release_placer::operator() (jit_block& block)
+{
+  for (jit_block::iterator iter = block.begin (); iter != block.end (); ++iter)
+    {
+      jit_instruction *instr = *iter;
+      for (size_t i = 0; i < instr->argument_count (); ++i)
+        {
+          jit_instruction *arg = instr->argument_instruction (i);
+          if (arg && arg->tag ())
+            {
+              jit_variable *tag = arg->tag ();
+              tag->stash_last_use (instr);
+            }
+        }
+
+      jit_variable *tag = instr->tag ();
+      if (tag && ! (instr->is_phi () || instr->is_store_argument ())
+          && tag->has_top ())
+        {
+          jit_instruction *last_use = tag->last_use ();
+          jit_call *release = convert.create<jit_call> (jit_typeinfo::release,
+                                                        tag->top ());
+          release->infer ();
+          if (last_use && last_use->parent () == &block
+              && ! last_use->is_phi ())
+            block.insert_after (last_use->location (), release);
+          else
+            block.prepend_after_phi (release);
+        }
+
+      instr->push_variable ();
+    }
+}
+
 // -------------------- jit_convert::convert_llvm --------------------
 llvm::Function *
 jit_convert::convert_llvm::convert (llvm::Module *module,
@@ -1812,42 +1875,10 @@
         {
           jit_block& block = **biter;
           for (jit_block::iterator piter = block.begin ();
-               piter != block.end () && dynamic_cast<jit_phi *> (*piter); ++piter)
+               piter != block.end () && (*piter)->is_phi (); ++piter)
             {
-              // our phi nodes don't have to have the same incomming type,
-              // so we do casts here
               jit_instruction *phi = *piter;
-              jit_block *pblock = phi->parent ();
-              llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ());
-              for (size_t i = 0; i < phi->argument_count (); ++i)
-                {
-                  llvm::BasicBlock *pred = pblock->pred_llvm (i);
-                  if (phi->argument_type_llvm (i) == phi->type_llvm ())
-                    {
-                      llvm_phi->addIncoming (phi->argument_llvm (i), pred);
-                    }
-                  else
-                    {
-                      // add cast right before pred terminator
-                      builder.SetInsertPoint (--pred->end ());
-
-                      const jit_function::overload& ol
-                        = jit_typeinfo::cast (phi->type (),
-                                              phi->argument_type (i));
-                      if (! ol.function)
-                        {
-                          std::stringstream ss;
-                          ss << "No cast for phi(" << i << "): ";
-                          phi->print (ss);
-                          fail (ss.str ());
-                        }
-
-                      llvm::Value *casted;
-                      casted = builder.CreateCall (ol.function,
-                                                   phi->argument_llvm (i));
-                      llvm_phi->addIncoming (casted, pred);
-                    }
-                }
+              finish_phi (phi);
             }
         }
 
@@ -1864,6 +1895,87 @@
 }
 
 void
+jit_convert::convert_llvm::finish_phi (jit_instruction *phi)
+{
+  jit_block *pblock = phi->parent ();
+  llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ());
+
+  bool can_remove = llvm_phi->use_empty ();
+  if (! can_remove && llvm_phi->hasOneUse () && phi->use_count () == 1)
+    {
+      jit_instruction *user = phi->first_use ()->user ();
+      can_remove = user->is_call (); // must be a remove
+    }
+
+  if (can_remove)
+    {
+      // replace with releases along each incomming branch
+      while (! llvm_phi->use_empty ())
+        {
+          llvm::Instruction *llvm_instr;
+          llvm_instr = llvm::cast<llvm::Instruction> (llvm_phi->use_back ());
+          llvm_instr->eraseFromParent ();
+        }
+
+      llvm_phi->eraseFromParent ();
+      phi->stash_llvm (0);
+
+      for (size_t i = 0; i < phi->argument_count (); ++i)
+        {
+          jit_value *arg = phi->argument (i);
+          if (arg->has_llvm () && phi->argument_type (i) != phi->type ())
+            {
+              llvm::BasicBlock *pred = pblock->pred_llvm (i);
+              builder.SetInsertPoint (--pred->end ());
+              const jit_function::overload& ol
+                = jit_typeinfo::get_release (phi->argument_type (i));
+              if (! ol.function)
+                {
+                  std::stringstream ss;
+                  ss << "No release for phi(" << i << "): ";
+                  phi->print (ss);
+                  fail (ss.str ());
+                }
+
+              builder.CreateCall (ol.function, phi->argument_llvm (i));
+            }
+        }
+    }
+  else
+    {
+      for (size_t i = 0; i < phi->argument_count (); ++i)
+        {
+          llvm::BasicBlock *pred = pblock->pred_llvm (i);
+          if (phi->argument_type (i) == phi->type ())
+            {
+              llvm_phi->addIncoming (phi->argument_llvm (i), pred);
+            }
+          else
+            {
+              // add cast right before pred terminator
+              builder.SetInsertPoint (--pred->end ());
+
+              const jit_function::overload& ol
+                = jit_typeinfo::cast (phi->type (),
+                                      phi->argument_type (i));
+              if (! ol.function)
+                {
+                  std::stringstream ss;
+                  ss << "No cast for phi(" << i << "): ";
+                  phi->print (ss);
+                  fail (ss.str ());
+                }
+
+              llvm::Value *casted;
+              casted = builder.CreateCall (ol.function,
+                                           phi->argument_llvm (i));
+              llvm_phi->addIncoming (casted, pred);
+            }
+        }
+    }
+}
+
+void
 jit_convert::convert_llvm::visit (jit_const_string& cs)
 {
   cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ()));