changeset 14913:c7071907a641

Use symbol_record_ref instead of names in JIT * src/pt-id.h (tree_identifier::symbol): New function. * src/symtab.h (tree_identifier::symbol_record_ref::operator->): Added const variant. * src/pt-jit.h: Use symbol_record_ref * src/pt-jit.cc: Use symbol_record_ref
author Max Brister <max@2bass.com>
date Fri, 18 May 2012 10:22:34 -0600
parents 3d3c002ccc60
children 1e5eafcb83f8
files src/pt-id.h src/pt-jit.cc src/pt-jit.h src/symtab.h
diffstat 4 files changed, 146 insertions(+), 111 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-id.h	Fri May 18 08:53:26 2012 -0600
+++ b/src/pt-id.h	Fri May 18 10:22:34 2012 -0600
@@ -114,6 +114,10 @@
 
   void accept (tree_walker& tw);
 
+  symbol_table::symbol_record_ref symbol (void) const
+  {
+    return sym;
+  }
 private:
 
   // The symbol record that this identifier references.
--- a/src/pt-jit.cc	Fri May 18 08:53:26 2012 -0600
+++ b/src/pt-jit.cc	Fri May 18 10:22:34 2012 -0600
@@ -478,7 +478,6 @@
 void
 jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv)
 {
-  // duplication here can probably be removed somehow
   if (type == any)
     to_generic (type, gv, octave_value ());
   else if (type == scalar)
@@ -557,9 +556,6 @@
 void
 jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds)
 {
-  argin.insert ("#bounds");
-  types["#bounds"] = bounds;
-
   infer_simple_for (cmd, bounds);
 }
 
@@ -690,7 +686,8 @@
 void
 jit_infer::visit_identifier (tree_identifier& ti)
 {
-  handle_identifier (ti.name (), ti.do_lookup ());
+  symbol_table::symbol_record_ref record = ti.symbol ();
+  handle_identifier (record);
 }
 
 void
@@ -853,7 +850,9 @@
           is_lvalue = true;
           rvalue_type = type_stack.back ();
           type_stack.pop_back ();
-          handle_identifier ("ans", symbol_table::varval ("ans"));
+
+          symbol_table::symbol_record_ref record (symbol_table::insert ("ans"));
+          handle_identifier (record);
 
           if (rvalue_type != type_stack.back ())
             fail ();
@@ -946,12 +945,13 @@
 }
 
 void
-jit_infer::handle_identifier (const std::string& name, octave_value v)
+jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record)
 {
-  type_map::iterator iter = types.find (name);
+  type_map::iterator iter = types.find (record);
   if (iter == types.end ())
     {
-      jit_type *ty = tinfo->type_of (v);
+      jit_type *ty = tinfo->type_of (record->find ());
+      bool argin = false;
       if (is_lvalue)
         {
           if (! ty)
@@ -961,68 +961,46 @@
         {
           if (! ty)
             fail ();
-
-          argin.insert (name);
+          argin = true;
         }
 
-      types[name] = ty;
+      types[record] = type_entry (argin, ty);
       type_stack.push_back (ty);
     }
   else
-    type_stack.push_back (iter->second);
+    type_stack.push_back (iter->second.second);
 }
 
 // -------------------- jit_generator --------------------
-jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee,
-                              const std::set<std::string>& argin,
-                              const type_map& infered_types, bool have_bounds)
-  : tinfo (ti), is_lvalue (false)
+jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod,
+                              tree_simple_for_command& cmd, jit_type *bounds,
+                              const type_map& infered_types)
+  : tinfo (ti), module (mod), is_lvalue (false)
 {
-  // determine the function type through the type of all variables
-  std::vector<llvm::Type *> arg_types (infered_types.size ());
-  size_t idx = 0;
+  // create new vectors that include bounds
+  std::vector<std::string> names (infered_types.size () + 1);
+  std::vector<bool> argin (infered_types.size () + 1);
+  std::vector<jit_type *> types (infered_types.size () + 1);
+  names[0] = "#bounds";
+  argin[0] = true;
+  types[0] = bounds;
+  size_t i;
   type_map::const_iterator iter;
-  for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++idx)
-    arg_types[idx] = iter->second->to_llvm_arg ();
-
-  // now create the LLVM function from our determined types
-  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
-  llvm::Type *tvoid = llvm::Type::getVoidTy (ctx);
-  llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false);
-  function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                                     "foobar", module);
-
-  // declare each argument and copy its initial value
-  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function);
-  builder.SetInsertPoint (body);
-  llvm::Function::arg_iterator arg_iter = function->arg_begin();
-  for (iter = infered_types.begin (); iter != infered_types.end ();
-       ++iter, ++arg_iter)
-
+  for (i = 1, iter = infered_types.begin (); iter != infered_types.end ();
+       ++i, ++iter)
     {
-      llvm::Type *vartype = iter->second->to_llvm ();
-      llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first);
-      variables[iter->first] = value (iter->second, var);
-
-      if (iter->second->force_init () || argin.count (iter->first))
-        {
-          llvm::Value *loaded_arg = builder.CreateLoad (arg_iter);
-          builder.CreateStore (loaded_arg, var);
-        }
+      names[i] = iter->first.name ();
+      argin[i] = iter->second.first;
+      types[i] = iter->second.second;
     }
 
-  // generate body
+  initialize (names, argin, types);
+
   try
     {
-      tree_simple_for_command *cmd = dynamic_cast<tree_simple_for_command*>(&tee);
-      if (have_bounds && cmd)
-        {
-          value bounds = variables["#bounds"];
-          bounds.second = builder.CreateLoad (bounds.second);
-          emit_simple_for (*cmd, bounds, true);
-        }
-      else
-        tee.accept (*this);
+      value var_bounds = variables["#bounds"];
+      var_bounds.second = builder.CreateLoad (var_bounds.second);
+      emit_simple_for (cmd, var_bounds, true);
     }
   catch (const jit_fail_exception&)
     {
@@ -1031,16 +1009,7 @@
       return;
     }
 
-  // copy computed values back into arguments
-  arg_iter = function->arg_begin ();
-  for (iter = infered_types.begin (); iter != infered_types.end ();
-       ++iter, ++arg_iter)
-    {
-      llvm::Value *var = variables[iter->first].second;
-      llvm::Value *loaded_var = builder.CreateLoad (var);
-      builder.CreateStore (loaded_var, arg_iter);
-    }
-  builder.CreateRetVoid ();
+  finalize (names);
 }
 
 void
@@ -1513,6 +1482,56 @@
   builder.CreateCall2 (ol.function, str, v.second);
 }
 
+void
+jit_generator::initialize (const std::vector<std::string>& names,
+                           const std::vector<bool>& argin,
+                           const std::vector<jit_type *> types)
+{
+  std::vector<llvm::Type *> arg_types (names.size ());
+  for (size_t i = 0; i < types.size (); ++i)
+    arg_types[i] = types[i]->to_llvm_arg ();
+
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::Type *tvoid = llvm::Type::getVoidTy (ctx);
+  llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false);
+  function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
+                                     "foobar", module);
+
+  // create variables and copy initial values
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function);
+  builder.SetInsertPoint (body);
+  llvm::Function::arg_iterator arg_iter = function->arg_begin();
+  for (size_t i = 0; i < names.size (); ++i, ++arg_iter)
+    {
+      llvm::Type *vartype = types[i]->to_llvm ();
+      const std::string& name = names[i];
+      llvm::Value *var = builder.CreateAlloca (vartype, 0, name);
+      variables[name] = value (types[i], var);
+
+      if (argin[i] || types[i]->force_init ())
+        {
+          llvm::Value *loaded_arg = builder.CreateLoad (arg_iter);
+          builder.CreateStore (loaded_arg, var);
+        }
+    }
+}
+
+void
+jit_generator::finalize (const std::vector<std::string>& names)
+{
+  // copy computed values back into arguments
+  // we use names instead of looping through variables because order is
+  // important
+  llvm::Function::arg_iterator arg_iter = function->arg_begin();
+  for (size_t i = 0; i < names.size (); ++i, ++arg_iter)
+    {
+      llvm::Value *var = variables[names[i]].second;
+      llvm::Value *loaded_var = builder.CreateLoad (var);
+      builder.CreateStore (loaded_var, arg_iter);
+    }
+  builder.CreateRetVoid ();
+}
+
 // -------------------- tree_jit --------------------
 
 tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0)
@@ -1584,7 +1603,8 @@
 // -------------------- jit_info --------------------
 jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd,
                     jit_type *bounds) : tinfo (tjit.get_typeinfo ()),
-                                        engine (tjit.get_engine ())
+                                        engine (tjit.get_engine ()),
+                                        bounds_t (bounds)
 {
   jit_infer infer(tinfo);
 
@@ -1598,10 +1618,9 @@
       return;
     }
 
-  argin = infer.get_argin ();
   types = infer.get_types ();
 
-  jit_generator gen(tinfo, tjit.get_module (), cmd, argin, types);
+  jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types);
   function = gen.get_function ();
 
   if (function)
@@ -1635,31 +1654,29 @@
   if (! function)
     return false;
 
-  std::vector<llvm::GenericValue> args (types.size ());
+  std::vector<llvm::GenericValue> args (types.size () + 1);
+  tinfo->to_generic (bounds_t, args[0], bounds);
+
   size_t idx;
   type_map::const_iterator iter;
-  for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx)
+  for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx)
     {
-      if (argin.count (iter->first))
+      if (iter->second.first) // argin?
         {
-          octave_value ov;
-          if (iter->first == "#bounds")
-            ov = bounds;
-          else
-            ov = symbol_table::varval (iter->first);
-
-          tinfo->to_generic (iter->second, args[idx], ov);
+          octave_value ov = iter->first->varval ();
+          tinfo->to_generic (iter->second.second, args[idx], ov);
         }
       else
-        tinfo->to_generic (iter->second, args[idx]);
+        tinfo->to_generic (iter->second.second, args[idx]);
     }
 
   engine->runFunction (function, args);
 
-  for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx)
+  for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx)
     {
-      octave_value result = tinfo->to_octave_value (iter->second, args[idx]);
-      symbol_table::varref (iter->first) = result;
+      octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]);
+      octave_value &ref = iter->first->varref ();
+      ref = result;
     }
 
   tinfo->reset_generic ();
@@ -1670,19 +1687,20 @@
 bool
 jit_info::match () const
 {
-  for (std::set<std::string>::iterator iter = argin.begin ();
-       iter != argin.end (); ++iter)
+  for (type_map::const_iterator iter = types.begin (); iter != types.end ();
+       ++iter)
+       
     {
-      if (*iter == "#bounds")
-        continue;
+      if (iter->second.first) // argin?
+        {
+          jit_type *required_type = iter->second.second;
+          octave_value val = iter->first->varval ();
+          jit_type *current_type = tinfo->type_of (val);
 
-      jit_type *required_type = types.find (*iter)->second;
-      octave_value val = symbol_table::varref (*iter);
-      jit_type *current_type = tinfo->type_of (val);
-
-      // FIXME: should be: ! required_type->is_parent (current_type)
-      if (required_type != current_type)
-        return false;
+          // FIXME: should be: ! required_type->is_parent (current_type)
+          if (required_type != current_type)
+            return false;
+        }
     }
 
   return true;
--- a/src/pt-jit.h	Fri May 18 08:53:26 2012 -0600
+++ b/src/pt-jit.h	Fri May 18 10:22:34 2012 -0600
@@ -32,6 +32,7 @@
 #include "Array.h"
 #include "Range.h"
 #include "pt-walk.h"
+#include "symtab.h"
 
 // -------------------- Current status --------------------
 // Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized.
@@ -295,8 +296,6 @@
 
   void reset_generic (void);
 private:
-  typedef std::map<std::string, jit_type *> type_map;
-
   jit_type *new_type (const std::string& name, bool force_init,
                       jit_type *parent, llvm::Type *llvm_type);
 
@@ -332,14 +331,16 @@
 class
 jit_infer : public tree_walker
 {
-  typedef std::map<std::string, jit_type *> type_map;
 public:
+  // pair <argin, type>
+  typedef std::pair<bool, jit_type *> type_entry;
+  typedef std::map<symbol_table::symbol_record_ref, type_entry,
+                   symbol_table::symbol_record_ref::comparator> type_map;
+
   jit_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false),
                                   rvalue_type (0)
   {}
 
-  const std::set<std::string>& get_argin () const { return argin; }
-
   const type_map& get_types () const { return types; }
 
   void infer (tree_simple_for_command& cmd, jit_type *bounds);
@@ -433,7 +434,7 @@
   void infer_simple_for (tree_simple_for_command& cmd,
                          jit_type *bounds);
 
-  void handle_identifier (const std::string& name, octave_value v);
+  void handle_identifier (const symbol_table::symbol_record_ref& record);
 
   jit_typeinfo *tinfo;
 
@@ -441,7 +442,6 @@
   jit_type *rvalue_type;
 
   type_map types;
-  std::set<std::string> argin;
 
   std::vector<jit_type *> type_stack;
 };
@@ -449,11 +449,11 @@
 class
 jit_generator : public tree_walker
 {
-  typedef std::map<std::string, jit_type *> type_map;
 public:
-  jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee,
-                 const std::set<std::string>& argin,
-                 const type_map& infered_types, bool have_bounds = true);
+  typedef jit_infer::type_map type_map;
+
+  jit_generator (jit_typeinfo *ti, llvm::Module *mod, tree_simple_for_command &cmd,
+                 jit_type *bounds, const type_map& infered_types);
 
   llvm::Function *get_function () const { return function; }
 
@@ -555,7 +555,14 @@
     value_stack.push_back (value (type, v));
   }
 
+  void initialize (const std::vector<std::string>& names,
+                   const std::vector<bool>& argin,
+                   const std::vector<jit_type *> types);
+
+  void finalize (const std::vector<std::string>& names);
+
   jit_typeinfo *tinfo;
+  llvm::Module *module;
   llvm::Function *function;
 
   bool is_lvalue;
@@ -596,19 +603,19 @@
 jit_info
 {
 public:
+  typedef jit_infer::type_map type_map;
+
   jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds);
 
   bool execute (const octave_value& bounds) const;
 
   bool match (void) const;
 private:
-  typedef std::map<std::string, jit_type *> type_map;
-  
   jit_typeinfo *tinfo;
   llvm::ExecutionEngine *engine;
-  std::set<std::string> argin;
   type_map types;
   llvm::Function *function;
+  jit_type *bounds_t;
 };
 
 #endif
--- a/src/symtab.h	Fri May 18 08:53:26 2012 -0600
+++ b/src/symtab.h	Fri May 18 10:22:34 2012 -0600
@@ -610,6 +610,12 @@
       return &sym;
     }
 
+    symbol_record *operator-> (void) const
+    {
+      update ();
+      return &sym;
+    }
+
     // can be used to place symbol_record_ref in maps, we don't overload < as
     // it doesn't make any sense for symbol_record_ref
     struct comparator
@@ -621,7 +627,7 @@
       }
     };
   private:
-    void update (void)
+    void update (void) const
     {
       scope_id curr_scope = symbol_table::current_scope ();
       if (scope != curr_scope || ! sym.is_valid ())
@@ -631,8 +637,8 @@
         }
     }
 
-    scope_id scope;
-    symbol_record sym;
+    mutable scope_id scope;
+    mutable symbol_record sym;
   };
 
   class