changeset 14923:168cb10bb9c5

If, ifelse, and else statements JIT compile now
author Max Brister <max@2bass.com>
date Mon, 28 May 2012 23:19:41 -0500
parents 2e6f83b2f2b9
children d4d9a64db6aa
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 272 insertions(+), 92 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc	Sun May 27 22:57:55 2012 -0500
+++ b/src/pt-jit.cc	Mon May 28 23:19:41 2012 -0500
@@ -612,7 +612,7 @@
 jit_block *
 jit_use::user_parent (void) const
 {
-  return usr->parent ();
+  return muser->parent ();
 }
 
 // -------------------- jit_value --------------------
@@ -660,6 +660,23 @@
   return dynamic_cast<jit_terminator *> (last);
 }
 
+jit_block *
+jit_block::pred (size_t idx) const
+{
+  // FIXME: Make this O(1)
+  
+  // here we get the use in backwards order. This means we preserve phi
+  // information when new blocks are added
+  assert (idx < use_count ());
+  jit_use *use;
+  size_t real_idx = use_count () - idx - 1;
+  size_t i;
+  for (use = first_use (), i = 0; use && i < real_idx; ++i,
+         use = use->next ());
+    
+  return use->user_parent ();
+}
+
 llvm::Value *
 jit_block::pred_terminator_llvm (size_t idx) const
 {
@@ -667,6 +684,17 @@
   return term ? term->to_llvm () : 0;
 }
 
+size_t
+jit_block::pred_index (jit_block *apred) const
+{
+  for (size_t i = 0; i < pred_count (); ++i)
+    if (pred (i) == apred)
+      return i;
+
+  fail ("No such predecessor");
+  return 0; // silly compiler, why you warn?
+}
+
 void
 jit_block::create_merge (llvm::Function *inside, size_t pred_idx)
 {
@@ -704,6 +732,21 @@
   return term ? term->sucessor_count () : 0;
 }
 
+jit_phi *
+jit_block::search_phi (const std::string& tag_name, jit_value *adefault)
+{
+  jit_phi *ret;
+  for (iterator iter = begin (); iter != end ()
+         && (ret = dynamic_cast<jit_phi *> (*iter)); ++iter)
+    if (ret->tag () == tag_name)
+      return ret;
+
+  ret = new jit_phi (pred_count (), adefault);
+  ret->stash_tag (tag_name);
+  prepend (ret);
+  return ret;
+}
+
 llvm::BasicBlock *
 jit_block::to_llvm (void) const
 {
@@ -952,7 +995,7 @@
   // we need to do iter phi manually, for_map handles the rest
   jit_phi *iter_phi = new jit_phi (2);
   iter_phi->stash_tag ("#iter");
-  iter_phi->stash_argument (1, init_iter);
+  iter_phi->stash_argument (0, init_iter);
   body->append (iter_phi);
 
   variable_map *merge_vars = variables;
@@ -978,19 +1021,19 @@
   check = block->append (new jit_call (jit_typeinfo::for_check, control,
                                        iter_inc));
   block->append (new jit_cond_break (check, body, tail));
-  iter_phi->stash_argument (0, iter_inc);
+  iter_phi->stash_argument (1, iter_inc);
   body_vars.finish_phi (*variables);
+  merge (tail, *merge_vars, block, body_vars);
 
   blocks.push_back (tail);
   prot_tail.discard ();
   block = tail;
+  variables = merge_vars;
 
-  variables = merge_vars;
-  merge (body_vars);
   iter_phi = new jit_phi (2);
   iter_phi->stash_tag ("#iter");
-  iter_phi->stash_argument (0, iter_inc);
-  iter_phi->stash_argument (1, init_iter);
+  iter_phi->stash_argument (0, init_iter);
+  iter_phi->stash_argument (1, iter_inc);
   block->append (iter_phi);
   block->append (new jit_call (jit_typeinfo::release, iter_phi));
 }
@@ -1046,15 +1089,126 @@
 }
 
 void
-jit_convert::visit_if_command (tree_if_command&)
+jit_convert::visit_if_command (tree_if_command& cmd)
 {
-  fail ();
+  tree_if_command_list *lst = cmd.cmd_list ();
+  assert (lst); // jwe: Can this be null?
+  lst->accept (*this);
 }
 
 void
-jit_convert::visit_if_command_list (tree_if_command_list&)
+jit_convert::visit_if_command_list (tree_if_command_list& lst)
 {
-  fail ();
+  // Example code:
+  // if a == 1
+  //  c = c + 1;
+  // elseif b == 1
+  //  c = c + 2;
+  // else
+  //  c = c + 3;
+  // endif
+
+  // Generates:
+  // prev_block0: % pred - ?
+  //   #temp.0 = call binary== (a.0, 1)
+  //   cond_break #temp.0, if_body1, ifelse_cond2
+  // if_body1:
+  //   c.1 = call binary+ (c.0, 1)
+  //   break if_tail5
+  // ifelse_cond2:
+  //   #temp.1 = call binary== (b.0, 1)
+  //   cond_break #temp.1, ifelse_body3, else4
+  // ifelse_body3:
+  //   c.2 = call binary+ (c.0, 2)
+  //   break if_tail5
+  // else4:
+  //   c.3 = call binary+ (c.0, 3)
+  //   break if_tail5
+  // if_tail5:
+  //   c.4 = phi | if_body1 -> c.1
+  //             | ifelse_body3 -> c.2
+  //             | else4 -> c.3
+
+
+  tree_if_clause *last = lst.back ();
+  size_t last_else = static_cast<size_t> (last->is_else_clause ());
+
+  // entry_blocks represents the block you need to enter in order to execute
+  // the condition check for the ith clause. For the else, it is simple the
+  // else body. If there is no else body, then it is padded with the tail
+  std::vector<jit_block *> entry_blocks (lst.size () + 1 - last_else);
+  std::vector<variable_map *> branch_variables (lst.size (), 0);
+  std::vector<jit_block *> branch_blocks (lst.size (), 0); // final blocks
+  entry_blocks[0] = block;
+
+  // we need to construct blocks first, because they have jumps to eachother
+  tree_if_command_list::iterator iter = lst.begin ();
+  ++iter;
+  for (size_t i = 1; iter != lst.end (); ++iter, ++i)
+    {
+      tree_if_clause *tic = *iter;
+      if (tic->is_else_clause ())
+        entry_blocks[i] = new jit_block ("else");
+      else
+        entry_blocks[i] = new jit_block ("ifelse_cond");
+      cleanup_blocks.push_back (entry_blocks[i]);
+    }
+
+  jit_block *tail = new jit_block ("if_tail");
+  if (! last_else)
+    entry_blocks[entry_blocks.size () - 1] = tail;
+
+  // actually fill out the contents of our blocks. We store the variable maps
+  // at the end of each branch, this allows us to merge them in the tail
+  variable_map *prev_map = variables;
+  iter = lst.begin ();
+  for (size_t i = 0; iter != lst.end (); ++iter, ++i)
+    {
+      tree_if_clause *tic = *iter;
+      block = entry_blocks[i];
+      assert (block);
+      variables = prev_map;
+
+      if (i) // the first block is prev_block, so it has already been added
+        blocks.push_back (entry_blocks[i]);
+
+      if (! tic->is_else_clause ())
+        {
+          tree_expression *expr = tic->condition ();
+          jit_value *cond = visit (expr);
+
+          jit_block *body = new jit_block (i == 0 ? "if_body" : "ifelse_body");
+          blocks.push_back (body);
+
+          jit_instruction *br = new jit_cond_break (cond, body,
+                                                    entry_blocks[i + 1]);
+          block->append (br);
+          block = body;
+
+          variables = new compound_map (variables);
+          branch_variables[i] = variables;
+        }
+
+      tree_statement_list *stmt_lst = tic->commands ();
+      assert (stmt_lst); // jwe: Can this be null?
+      stmt_lst->accept (*this);
+
+      branch_variables[i] = variables;
+      branch_blocks[i] = block;
+      block->append (new jit_break (tail));
+    }
+
+  blocks.push_back (tail);
+
+  // We create phi nodes in the tail to merge blocks
+  for (size_t i = 0; i < branch_variables.size () - last_else; ++i)
+    {
+      merge (tail, *prev_map, branch_blocks[i], *branch_variables[i]);
+      delete branch_variables[i];
+    }
+
+  variables = prev_map;
+  block = tail;
 }
 
 void
@@ -1281,22 +1435,28 @@
 }
 
 void
-jit_convert::merge (const variable_map& ref)
+jit_convert::merge (jit_block *merge_block, variable_map& merge_vars,
+                    jit_block *incomming_block,
+                    const variable_map& incomming_vars)
 {
-  assert (variables->size () == ref.size ());
-  variable_map::iterator viter = variables->begin ();
-  variable_map::const_iterator riter = ref.begin ();
-  for (; viter != variables->end (); ++viter, ++riter)
+  size_t merge_idx = merge_block->pred_index (incomming_block);
+  for (variable_map::const_iterator iter = incomming_vars.begin ();
+       iter != incomming_vars.end (); ++iter)
     {
-      assert (viter->first == riter->first);
-      if (viter->second != riter->second)
+      const std::string& vname = iter->first;
+      jit_value *merge_val = merge_vars.get (vname);
+      jit_value *inc_val = iter->second;
+
+      if (merge_val != inc_val)
         {
-          jit_phi *phi = new jit_phi (2);
-          phi->stash_tag (viter->first);
-          block->prepend (phi);
-          phi->stash_argument (0, riter->second);
-          phi->stash_argument (1, viter->second);
-          viter->second = phi;
+          jit_phi *phi = dynamic_cast<jit_phi *> (merge_val);
+          if (! (phi && phi->parent () == merge_block))
+            {
+              phi = merge_block->search_phi (vname, merge_val);
+              merge_vars.set (vname, phi);
+            }
+
+          phi->stash_argument (merge_idx, inc_val);
         }
     }
 }
--- a/src/pt-jit.h	Sun May 27 22:57:55 2012 -0500
+++ b/src/pt-jit.h	Mon May 28 23:19:41 2012 -0500
@@ -46,7 +46,9 @@
 // b = a + a;
 // will compile to do_binary_op (a, a).
 //
-// For loops are compiled again! Additionally, make check passes using jit.
+// For loops are compiled again!
+// if, elseif, and else statements compile again!
+// Additionally, make check passes using jit.
 //
 // The octave low level IR is a linear IR, it works by converting everything to
 // calls to jit_functions. This turns expressions like c = a + b into
@@ -56,8 +58,8 @@
 //
 //
 // TODO:
-// 1. Support if statements
-// 2. Support error cases
+// 1. Support error cases
+// 2. Support break/continue
 // 3. Fix memory leaks in JIT
 // 4. Cleanup/documentation
 // 5. ...
@@ -566,6 +568,7 @@
 private:
   jit_type *ty;
   jit_use *use_head;
+  jit_use *use_tail;
   size_t myuse_count;
 };
 
@@ -625,68 +628,82 @@
 jit_use
 {
 public:
-  jit_use (void) : used (0), next_use (0), prev_use (0) {}
+  jit_use (void) : mvalue (0), mnext (0), mprev (0), muser (0), mindex (0) {}
+
+  // we should really have a move operator, but not until c++11 :(
+  jit_use (const jit_use& use) : mvalue (0), mnext (0), mprev (0), muser (0),
+                                 mindex (0)
+  {
+    *this = use;
+  }
 
   ~jit_use (void) { remove (); }
 
-  jit_value *value (void) const { return used; }
+  jit_use& operator= (const jit_use& use)
+  {
+    stash_value (use.value (), use.user (), use.index ());
+    return *this;
+  }
 
-  size_t index (void) const { return idx; }
+  jit_value *value (void) const { return mvalue; }
 
-  jit_instruction *user (void) const { return usr; }
+  size_t index (void) const { return mindex; }
+
+  jit_instruction *user (void) const { return muser; }
 
   jit_block *user_parent (void) const;
 
-  void stash_value (jit_value *new_value, jit_instruction *u = 0,
-                    size_t use_idx = -1)
+  void stash_value (jit_value *avalue, jit_instruction *auser = 0,
+                    size_t aindex = -1)
   {
     remove ();
 
-    used = new_value;
+    mvalue = avalue;
 
-    if (used)
+    if (mvalue)
       {
-        if (used->use_head)
+        if (mvalue->use_head)
           {
-            used->use_head->prev_use = this;
-            next_use = used->use_head;
+            mvalue->use_head->mprev = this;
+            mnext = mvalue->use_head;
           }
         
-        used->use_head = this;
-        ++used->myuse_count;
+        mvalue->use_head = this;
+        ++mvalue->myuse_count;
       }
 
-    idx = use_idx;
-    usr = u;
+    mindex = aindex;
+    muser = auser;
   }
 
-  jit_use *next (void) const { return next_use; }
+  jit_use *next (void) const { return mnext; }
 
-  jit_use *prev (void) const { return prev_use; }
+  jit_use *prev (void) const { return mprev; }
 private:
   void remove (void)
   {
-    if (used)
+    if (mvalue)
       {
-        if (this == used->use_head)
-            used->use_head = next_use;
+        if (this == mvalue->use_head)
+            mvalue->use_head = mnext;
 
-        if (prev_use)
-          prev_use->next_use = next_use;
+        if (mprev)
+          mprev->mnext = mnext;
 
-        if (next_use)
-          next_use->prev_use = prev_use;
+        if (mnext)
+          mnext->mprev = mprev;
 
-        next_use = prev_use = 0;
-        --used->myuse_count;
+        mnext = mprev = 0;
+        --mvalue->myuse_count;
+        mvalue = 0;
       }
   }
 
-  jit_value *used;
-  jit_use *next_use;
-  jit_use *prev_use;
-  jit_instruction *usr;
-  size_t idx;
+  jit_value *mvalue;
+  jit_use *mnext;
+  jit_use *mprev;
+  jit_instruction *muser;
+  size_t mindex;
 };
 
 class
@@ -697,10 +714,14 @@
   jit_instruction (void) : id (next_id ()), mparent (0)
   {}
 
-  jit_instruction (size_t nargs)
+  jit_instruction (size_t nargs, jit_value *adefault = 0)
   : already_infered (nargs, reinterpret_cast<jit_type *>(0)), arguments (nargs),
     id (next_id ()), mparent (0)
-  {}
+  {
+    if (adefault)
+      for (size_t i = 0; i < nargs; ++i)
+        stash_argument (i, adefault);
+  }
 
   jit_instruction (jit_value *arg0)
     : already_infered (1, reinterpret_cast<jit_type *>(0)), arguments (1), 
@@ -772,6 +793,16 @@
     return arguments.size ();
   }
 
+  void resize_arguments (size_t acount, jit_value *adefault = 0)
+  {
+    size_t old = arguments.size ();
+    arguments.resize (acount);
+
+    if (adefault)
+      for (size_t i = old; i < acount; ++i)
+        stash_argument (i, adefault);
+  }
+
   // argument types which have been infered already
   const std::vector<jit_type *>& argument_types (void) const
   { return already_infered; }
@@ -813,7 +844,7 @@
     return ret++;
   }
 
-  std::vector<jit_use> arguments; // DO NOT resize
+  std::vector<jit_use> arguments;
 
   std::string mtag;
   size_t id;
@@ -821,6 +852,7 @@
 };
 
 class jit_terminator;
+class jit_phi;
 
 class
 jit_block : public jit_value
@@ -848,18 +880,7 @@
 
   jit_terminator *terminator (void) const;
 
-  jit_block *pred (size_t idx) const
-  {
-    // FIXME: We should probably make this O(1)
-    jit_use *puse = first_use ();
-    for (size_t i = 0; i < idx; ++i)
-      {
-        assert (puse);
-        puse = puse->next ();
-      }
-
-    return puse->user_parent ();
-  }
+  jit_block *pred (size_t idx) const;
 
   jit_terminator *pred_terminator (size_t idx) const
   {
@@ -876,7 +897,7 @@
   // takes into account for the addition of phi merges
   llvm::BasicBlock *pred_llvm (size_t idx) const
   {
-    if (mpred_llvm.size () <= idx)
+    if (mpred_llvm.size () < pred_count ())
       mpred_llvm.resize (pred_count ());
 
     return mpred_llvm[idx] ? mpred_llvm[idx] : pred (idx)->to_llvm ();
@@ -887,19 +908,7 @@
     return pred_llvm (pred_index (apred));
   }
 
-  size_t pred_index (jit_block *apred) const
-  {
-    jit_use *puse = first_use ();
-    size_t idx = 0;
-    while (puse->user_parent () != apred)
-      {
-        assert (puse);
-        puse = puse->next ();
-        ++idx;
-      }
-
-    return idx;
-  }
+  size_t pred_index (jit_block *apred) const;
 
   // create llvm phi merge blocks for all predecessors (if required)
   void create_merge (llvm::Function *inside, size_t pred_idx);
@@ -916,6 +925,10 @@
 
   const_iterator end (void) const { return instructions.begin (); }
 
+  // search for the phi function with the given tag_name, if no function
+  // exists then a new phi node is created
+  jit_phi *search_phi (const std::string& tag_name, jit_value *adefault);
+
   virtual std::ostream& print (std::ostream& os, size_t indent)
   {
     print_indent (os, indent) << mname << ":\tpred = ";
@@ -953,7 +966,8 @@
 jit_phi : public jit_instruction
 {
 public:
-  jit_phi (size_t npred) : jit_instruction (npred)
+  jit_phi (size_t npred, jit_value *adefault = 0)
+    : jit_instruction (npred, adefault)
   {}
 
   virtual bool infer (void)
@@ -1347,6 +1361,8 @@
                                                               mblock (ablock)
     {}
 
+    virtual ~variable_map () {}
+
     variable_map *parent (void) const { return mparent; }
 
     jit_block *block (void) const { return mblock; }
@@ -1419,8 +1435,8 @@
            iter != for_body->end () && dynamic_cast<jit_phi *> (*iter); ++iter)
         {
           jit_instruction *node = *iter;
-          if (! node->argument (0))
-            node->stash_argument (0, from.get (node->tag ()));
+          if (! node->argument (1))
+            node->stash_argument (1, from.get (node->tag ()));
         }
     }
   protected:
@@ -1429,7 +1445,7 @@
       jit_phi *ret = new jit_phi (2);
       ret->stash_tag (name);
       block ()->prepend (ret);
-      ret->stash_argument (1, pval);
+      ret->stash_argument (0, pval);
       return vars[name] = ret;
     }
   };
@@ -1460,6 +1476,8 @@
 
   std::list<jit_block *> blocks;
 
+  std::list<jit_block *> cleanup_blocks;
+
   std::list<jit_instruction *> worklist;
 
   std::list<jit_value *> constants;
@@ -1486,7 +1504,9 @@
 
   // place phi nodes in the current block to merge ref with variables
   // we assume the same number of deffinitions
-  void merge (const variable_map& ref);
+  void merge (jit_block *merge_block, variable_map& merge_vars,
+              jit_block *incomming_block,
+              const variable_map& incomming_vars);
 
   // this case is much simpler, just convert from the jit ir to llvm
   class