diff libinterp/interp-core/pt-jit.cc @ 15337:3f43e9d6d86e

JIT compile anonymous functions * jit-ir.h (jit_block::front, jit_block::back): New function. (jit_call::jit_call): New overloads. (jit_return): New class. * jit-typeinfo.cc (octave_jit_create_undef): New function. (jit_operation::to_idx): Correctly handle empty type vector. (jit_typeinfo::jit_typeinfo): Add destroy_fn and initialize create_undef. * jit-typeinfo.h (jit_typeinfo::get_any_ptr, jit_typeinfo::destroy, jit_typeinfo::create_undef): New function. * pt-jit.cc (jit_convert::jit_convert): Add overload and refactor. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute, jit_function_info::jit_function_info, jit_function_info::execute, jit_function_info::match): New function. (jit_convert::get_variable): Support function variable lookup. (jit_convert_llvm::convert): Handle loop/function agnostic stuff. (jit_convert_llvm::visit): Handle function creation as well. (tree_jit::execute): Move implementation to tree_jit::do_execute. (jit_info::compile): Call convert_loop instead of convert. * pt-jit.h (jit_convert::jit_convert): New overload. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute): New function. (jit_convert::create_variable, jit_convert_llvm::initialize): Update signature. (tree_jit::execute): Made static. (tree_jit::tree_jit): Made private. (jit_function_info): New class. * ov-usr-fcn.cc (octave_user_function::~octave_user_function): Delete jit_info. (octave_user_function::octave_user_function): Maybe JIT and use is_special_expr and special_expr. (octave_user_function::special_expr): New function. * ov-usr-fcn.h (octave_user_function::is_special_expr, octave_user_function::special_expr, octave_user_function::get_info, octave_user_function::stash_info): New function. * pt-decl.h (tree_decl_elt::name): New function. 
* pt-eval.cc (tree_evaluator::visit_simple_for_command, tree_evaluator::visit_while_command): Use static tree_jit methods.
author Max Brister <max@2bass.com>
date Sun, 09 Sep 2012 00:29:00 -0600
parents 8125773322d4
children b49d707fe9d7
line wrap: on
line diff
--- a/libinterp/interp-core/pt-jit.cc	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/pt-jit.cc	Sun Sep 09 00:29:00 2012 -0600
@@ -65,22 +65,16 @@
 
 // -------------------- jit_convert --------------------
 jit_convert::jit_convert (tree &tee, jit_type *for_bounds)
-  : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false)
+  : converting_function (false)
 {
-  jit_instruction::reset_ids ();
-
-  entry_block = factory.create<jit_block> ("body");
-  final_block = factory.create<jit_block> ("final");
-  blocks.push_back (entry_block);
-  entry_block->mark_alive ();
-  block = entry_block;
+  initialize (symbol_table::current_scope ());
 
   if (for_bounds)
     create_variable (next_for_bounds (false), for_bounds);
 
   visit (tee);
 
-  // FIXME: Remove if we no longer only compile loops
+  // breaks must have been handled by the top level loop
   assert (! breaking);
   assert (breaks.empty ());
   assert (continues.empty ());
@@ -95,6 +89,91 @@
       if (name.size () && name[0] != '#')
         final_block->append (factory.create<jit_store_argument> (var));
     }
+
+  final_block->append (factory.create<jit_return> ());
+}
+
+// Convert the user function FCN into JIT IR, specialized for the argument
+// types in ARGS.
+//
+// ARGS[i] is bound to the parameter at position i of FCN's parameter list.
+// A jit_fail_exception is thrown when FCN cannot be compiled: varargs in the
+// parameter or return list, more than one return value, more arguments than
+// the parameter list declares, or a continue statement that would leave the
+// function.
+jit_convert::jit_convert (octave_user_function& fcn,
+                          const std::vector<jit_type *>& args)
+  : converting_function (true)
+{
+  initialize (fcn.scope ());
+
+  tree_parameter_list *plist = fcn.parameter_list ();
+  tree_parameter_list *rlist = fcn.return_list ();
+  if (plist && plist->takes_varargs ())
+    throw jit_fail_exception ("varargs not supported");
+
+  if (rlist && (rlist->size () > 1 || rlist->takes_varargs ()))
+    throw jit_fail_exception ("multiple returns not supported");
+
+  // bind each supplied argument type to the parameter in the same position
+  if (plist)
+    {
+      tree_parameter_list::iterator piter = plist->begin ();
+      for (size_t i = 0; i < args.size (); ++i, ++piter)
+        {
+          if (piter == plist->end ())
+            throw jit_fail_exception ("Too many parameters to function");
+
+          tree_decl_elt *elt = *piter;
+          std::string name = elt->name ();
+          create_variable (name, args[i]);
+        }
+    }
+
+  // a "special expression" function has a single expression instead of a
+  // statement list; its value is stored in the "#return" variable
+  jit_value *return_value = 0;
+  if (fcn.is_special_expr ())
+    {
+      tree_expression *expr = fcn.special_expr ();
+      if (expr)
+        {
+          jit_variable *retvar = get_variable ("#return");
+          jit_value *retval = visit (expr);
+          block->append (factory.create<jit_assign> (retvar, retval));
+          return_value = retvar;
+        }
+    }
+  else
+    visit_statement_list (*fcn.body ());
+
+  // The user may use break or continue to exit the function.  Leaving the
+  // function through continue is not supported yet: the previous code
+  // asserted that no continue was pending inside the branch taken only when
+  // one was pending, so it could never succeed.  Fail the compilation
+  // instead so the caller can fall back to the interpreter.
+  // NOTE(review): a body that both falls through and has a pending break
+  // only takes the break branch below -- confirm that the fall through
+  // block still gets a branch to final_block in that case.
+  if (! continues.empty ())
+    throw jit_fail_exception ("continue outside of a loop not supported");
+  else if (! breaks.empty ())
+    finish_breaks (final_block, breaks);
+  else
+    block->append (factory.create<jit_branch> (final_block));
+  blocks.push_back (final_block);
+  block = final_block;
+
+  // without a special expression the single declared output, if any, is the
+  // return value
+  if (! return_value && rlist && rlist->size () == 1)
+    {
+      tree_decl_elt *elt = rlist->front ();
+      return_value = get_variable (elt->name ());
+    }
+
+  // FIXME: We should use live range analysis to delete variables where needed.
+  // For now we just delete everything at the end of the function.
+  for (variable_map::iterator iter = vmap.begin (); iter != vmap.end (); ++iter)
+    {
+      // the returned variable must outlive the function
+      if (iter->second != return_value)
+        {
+          jit_call *call;
+          call = factory.create<jit_call> (&jit_typeinfo::destroy,
+                                           iter->second);
+          final_block->append (call);
+        }
+    }
+
+  if (return_value)
+    final_block->append (factory.create<jit_return> (return_value));
+  else
+    final_block->append (factory.create<jit_return> ());
 }
 
 void
@@ -719,6 +798,23 @@
   throw jit_fail_exception ();
 }
 
+// Reset the per-conversion state for a compilation in scope S and create
+// the "body" entry block and the "final" exit block.  Only the entry block
+// is added to the block list here; it also becomes the current block.
+void
+jit_convert::initialize (symbol_table::scope_id s)
+{
+  jit_instruction::reset_ids ();
+
+  scope = s;
+  breaking = false;
+  short_count = 0;
+  for_bounds_count = 0;
+  iterator_count = 0;
+
+  entry_block = factory.create<jit_block> ("body");
+  final_block = factory.create<jit_block> ("final");
+
+  blocks.push_back (entry_block);
+  entry_block->mark_alive ();
+  block = entry_block;
+}
+
 jit_call *
 jit_convert::create_checked_impl (jit_call *ret)
 {
@@ -749,20 +845,42 @@
   if (ret)
     return ret;
 
-  octave_value val = symbol_table::find (vname);
-  jit_type *type = jit_typeinfo::type_of (val);
-  bounds.push_back (type_bound (type, vname));
+  symbol_table::symbol_record record = symbol_table::find_symbol (vname, scope);
+  if (record.is_persistent () || record.is_global ())
+    throw jit_fail_exception ("Persistent and global not yet supported");
 
-  return create_variable (vname, type);
+  if (converting_function)
+    return create_variable (vname, jit_typeinfo::get_any (), false);
+  else
+    {
+      octave_value val = record.varval ();
+      jit_type *type = jit_typeinfo::type_of (val);
+      bounds.push_back (type_bound (type, vname));
+
+      return create_variable (vname, type);
+    }
 }
 
 jit_variable *
-jit_convert::create_variable (const std::string& vname, jit_type *type)
+jit_convert::create_variable (const std::string& vname, jit_type *type,
+                              bool isarg)
 {
   jit_variable *var = factory.create<jit_variable> (vname);
-  jit_extract_argument *extract;
-  extract = factory.create<jit_extract_argument> (type, var);
-  entry_block->prepend (extract);
+
+  if (isarg)
+    {
+      jit_extract_argument *extract;
+      extract = factory.create<jit_extract_argument> (type, var);
+      entry_block->prepend (extract);
+    }
+  else
+    {
+      jit_call *init = factory.create<jit_call> (&jit_typeinfo::create_undef);
+      jit_assign *assign = factory.create<jit_assign> (var, init);
+      entry_block->prepend (assign);
+      entry_block->prepend (init);
+    }
+
   return vmap[vname] = var;
 }
 
@@ -898,10 +1016,12 @@
 
 // -------------------- jit_convert_llvm --------------------
 llvm::Function *
-jit_convert_llvm::convert (llvm::Module *module,
-                           const jit_block_list& blocks,
-                           const std::list<jit_value *>& constants)
+jit_convert_llvm::convert_loop (llvm::Module *module,
+                                const jit_block_list& blocks,
+                                const std::list<jit_value *>& constants)
 {
+  converting_function = false;
+
   // for now just init arguments from entry, later we will have to do something
   // more interesting
   jit_block *entry_block = blocks.front ();
@@ -934,44 +1054,7 @@
           arguments[argument_vec[i].first] = loaded_arg;
         }
 
-      std::list<jit_block *>::const_iterator biter;
-      for (biter = blocks.begin (); biter != blocks.end (); ++biter)
-        {
-          jit_block *jblock = *biter;
-          llvm::BasicBlock *block = llvm::BasicBlock::Create (context,
-                                                              jblock->name (),
-                                                              function);
-          jblock->stash_llvm (block);
-        }
-
-      jit_block *first = *blocks.begin ();
-      builder.CreateBr (first->to_llvm ());
-
-      // constants aren't in the IR, we visit those first
-      for (std::list<jit_value *>::const_iterator iter = constants.begin ();
-           iter != constants.end (); ++iter)
-        if (! isa<jit_instruction> (*iter))
-          visit (*iter);
-
-      // convert all instructions
-      for (biter = blocks.begin (); biter != blocks.end (); ++biter)
-        visit (*biter);
-
-      // now finish phi nodes
-      for (biter = blocks.begin (); biter != blocks.end (); ++biter)
-        {
-          jit_block& block = **biter;
-          for (jit_block::iterator piter = block.begin ();
-               piter != block.end () && isa<jit_phi> (*piter); ++piter)
-            {
-              jit_instruction *phi = *piter;
-              finish_phi (static_cast<jit_phi *> (phi));
-            }
-        }
-
-      jit_block *last = blocks.back ();
-      builder.SetInsertPoint (last->to_llvm ());
-      builder.CreateRetVoid ();
+      convert (blocks, constants);
     } catch (const jit_fail_exception& e)
     {
       function->eraseFromParent ();
@@ -981,6 +1064,92 @@
   return function;
 }
 
+
+// Lower the IR in BLOCKS, built for the user function FCN with argument
+// types ARGS, to an LLVM function in MODULE and return it as a jit_function.
+// The LLVM function is erased again if lowering throws jit_fail_exception.
+jit_function
+jit_convert_llvm::convert_function (llvm::Module *module,
+                                    const jit_block_list& blocks,
+                                    const std::list<jit_value *>& constants,
+                                    octave_user_function& fcn,
+                                    const std::vector<jit_type *>& args)
+{
+  converting_function = true;
+
+  // the result type is taken from the jit_return ending the last block
+  jit_block *last_block = blocks.back ();
+  jit_return *final_ret = dynamic_cast<jit_return *> (last_block->back ());
+  assert (final_ret);
+
+  jit_function result = jit_function (module, jit_convention::internal,
+                                      "foobar", final_ret->result_type (),
+                                      args);
+  function = result.to_llvm ();
+
+  try
+    {
+      prelude = result.new_block ("prelude");
+      builder.SetInsertPoint (prelude);
+
+      if (tree_parameter_list *plist = fcn.parameter_list ())
+        {
+          // map each parameter name to the LLVM argument at the same position
+          size_t idx = 0;
+          for (tree_parameter_list::iterator piter = plist->begin ();
+               idx < args.size () && piter != plist->end (); ++idx, ++piter)
+            {
+              std::string param_name = (*piter)->name ();
+              arguments[param_name] = result.argument (builder, idx);
+            }
+        }
+
+      convert (blocks, constants);
+    }
+  catch (const jit_fail_exception&)
+    {
+      function->eraseFromParent ();
+      throw;
+    }
+
+  return result;
+}
+
+// Lower BLOCKS and CONSTANTS into the LLVM function stored in FUNCTION.
+// Called by convert_loop and convert_function after they have set up their
+// arguments.
+void
+jit_convert_llvm::convert (const jit_block_list& blocks,
+                           const std::list<jit_value *>& constants)
+{
+  typedef std::list<jit_block *>::const_iterator block_citer;
+  typedef std::list<jit_value *>::const_iterator value_citer;
+
+  // create an LLVM basic block for each IR block up front
+  for (block_citer it = blocks.begin (); it != blocks.end (); ++it)
+    {
+      jit_block *jblock = *it;
+      llvm::BasicBlock *llvm_block
+        = llvm::BasicBlock::Create (context, jblock->name (), function);
+      jblock->stash_llvm (llvm_block);
+    }
+
+  // branch from the current insertion point to the first IR block
+  jit_block *first = *blocks.begin ();
+  builder.CreateBr (first->to_llvm ());
+
+  // constants aren't in the IR, we visit those first
+  for (value_citer vit = constants.begin (); vit != constants.end (); ++vit)
+    {
+      if (isa<jit_instruction> (*vit))
+        continue;
+
+      visit (*vit);
+    }
+
+  // convert all instructions
+  for (block_citer it = blocks.begin (); it != blocks.end (); ++it)
+    visit (*it);
+
+  // now finish phi nodes, which are the leading instructions of a block
+  for (block_citer it = blocks.begin (); it != blocks.end (); ++it)
+    {
+      jit_block& jblock = **it;
+      jit_block::iterator piter = jblock.begin ();
+      while (piter != jblock.end () && isa<jit_phi> (*piter))
+        {
+          jit_instruction *phi = *piter;
+          finish_phi (static_cast<jit_phi *> (phi));
+          ++piter;
+        }
+    }
+}
+
 void
 jit_convert_llvm::finish_phi (jit_phi *phi)
 {
@@ -1089,10 +1258,16 @@
 {
   llvm::Value *arg = arguments[extract.name ()];
   assert (arg);
-  arg = builder.CreateLoad (arg);
 
-  const jit_function& ol = extract.overload ();
-  extract.stash_llvm (ol.call (builder, arg));
+  if (converting_function)
+    extract.stash_llvm (arg);
+  else
+    {
+      arg = builder.CreateLoad (arg);
+
+      const jit_function& ol = extract.overload ();
+      extract.stash_llvm (ol.call (builder, arg));
+    }
 }
 
 void
@@ -1105,6 +1280,16 @@
 }
 
 void
+jit_convert_llvm::visit (jit_return& ret)
+{
+  // a jit_return without a result is emitted as a void return
+  if (jit_value *res = ret.result ())
+    builder.CreateRet (res->to_llvm ());
+  else
+    builder.CreateRetVoid ();
+}
+
+void
 jit_convert_llvm::visit (jit_phi& phi)
 {
   // we might not have converted all incoming branches, so we don't
@@ -1539,44 +1724,27 @@
 bool
 tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds)
 {
-  const size_t MIN_TRIP_COUNT = 1000;
-
-  size_t tc = trip_count (bounds);
-  if (! tc || ! initialize ())
-    return false;
-
-  jit_info::vmap extra_vars;
-  extra_vars["#for_bounds0"] = &bounds;
-
-  jit_info *info = cmd.get_info ();
-  if (! info || ! info->match (extra_vars))
-    {
-      if (tc < MIN_TRIP_COUNT)
-        return false;
-
-      delete info;
-      info = new jit_info (*this, cmd, bounds);
-      cmd.stash_info (info);
-    }
-
-  return info->execute (extra_vars);
+  return instance ().do_execute (cmd, bounds);
 }
 
 bool
 tree_jit::execute (tree_while_command& cmd)
 {
-  if (! initialize ())
-    return false;
+  return instance ().do_execute (cmd);
+}
 
-  jit_info *info = cmd.get_info ();
-  if (! info || ! info->match ())
-    {
-      delete info;
-      info = new jit_info (*this, cmd);
-      cmd.stash_info (info);
-    }
+// Try to run FCN with ARGS through the JIT, forwarding to the tree_jit
+// returned by instance ().  The result of do_execute is returned unchanged.
+bool
+tree_jit::execute (octave_user_function& fcn, const octave_value_list& args,
+                   octave_value_list& retval)
+{
+  tree_jit& jit = instance ();
+  return jit.do_execute (fcn, args, retval);
+}
 
-  return info->execute ();
+// Return the tree_jit object used by the static execute entry points.
+tree_jit&
+tree_jit::instance (void)
+{
+  // function-local static: the same object is returned on every call
+  static tree_jit the_jit;
+  return the_jit;
 }
 
 bool
@@ -1616,6 +1784,67 @@
   return true;
 }
 
+// Try to run the for loop CMD over BOUNDS with the JIT.  Returns false
+// without running anything when the trip count is zero, the JIT cannot be
+// initialized, or a new compilation would be needed for a loop with fewer
+// than MIN_TRIP_COUNT iterations.  Otherwise returns the result of
+// jit_info::execute.
+bool
+tree_jit::do_execute (tree_simple_for_command& cmd, const octave_value& bounds)
+{
+  const size_t MIN_TRIP_COUNT = 1000;
+
+  const size_t trips = trip_count (bounds);
+  if (trips == 0)
+    return false;
+
+  if (! initialize ())
+    return false;
+
+  // the loop bounds are passed to the compiled code as an extra variable
+  jit_info::vmap extra_vars;
+  extra_vars["#for_bounds0"] = &bounds;
+
+  jit_info *info = cmd.get_info ();
+  const bool reusable = info && info->match (extra_vars);
+  if (! reusable)
+    {
+      // do not compile for loops that run only a few iterations
+      if (trips < MIN_TRIP_COUNT)
+        return false;
+
+      delete info;
+      info = new jit_info (*this, cmd, bounds);
+      cmd.stash_info (info);
+    }
+
+  return info->execute (extra_vars);
+}
+
+// Try to run the while loop CMD with the JIT.  The jit_info cached on CMD is
+// reused when it matches; otherwise it is replaced by a new compilation.
+// Returns false when the JIT cannot be initialized, else the result of
+// jit_info::execute.
+bool
+tree_jit::do_execute (tree_while_command& cmd)
+{
+  if (! initialize ())
+    return false;
+
+  jit_info *info = cmd.get_info ();
+  const bool reusable = info && info->match ();
+  if (! reusable)
+    {
+      delete info;
+      info = new jit_info (*this, cmd);
+      cmd.stash_info (info);
+    }
+
+  return info->execute ();
+}
+
+// Try to run the user function FCN on ARGS with the JIT, storing any result
+// in RETVAL.  The jit_function_info cached on FCN is reused when it matches
+// ARGS; otherwise it is replaced by a new compilation.  Returns false when
+// the JIT cannot be initialized, else the result of
+// jit_function_info::execute.
+bool
+tree_jit::do_execute (octave_user_function& fcn, const octave_value_list& args,
+                      octave_value_list& retval)
+{
+  if (! initialize ())
+    return false;
+
+  jit_function_info *info = fcn.get_info ();
+  const bool reusable = info && info->match (args);
+  if (! reusable)
+    {
+      delete info;
+      info = new jit_function_info (*this, fcn, args);
+      fcn.stash_info (info);
+    }
+
+  return info->execute (args, retval);
+}
+
 size_t
 tree_jit::trip_count (const octave_value& bounds) const
 {
@@ -1644,6 +1873,163 @@
 #endif
 }
 
+// -------------------- jit_function_info --------------------
+// Compile FCN for the argument types found in OV_ARGS.
+//
+// The function is converted to JIT IR, type inferred, and lowered to a raw
+// LLVM function.  A wrapper taking a single pointer to an array of "any"
+// values is generated around it, optimized, and its native entry point is
+// stored in FUNCTION.  If any step throws jit_fail_exception, FUNCTION stays
+// 0 and ARGUMENT_TYPES is cleared.
+jit_function_info::jit_function_info (tree_jit& tjit,
+                                      octave_user_function& fcn,
+                                      const octave_value_list& ov_args)
+  : argument_types (ov_args.length ()), function (0)
+{
+  // the compiled code is specialized for the types of the current arguments
+  size_t nargs = ov_args.length ();
+  for (size_t i = 0; i < nargs; ++i)
+    argument_types[i] = jit_typeinfo::type_of (ov_args(i));
+
+  try
+    {
+      jit_convert conv (fcn, argument_types);
+      jit_infer infer (conv.get_factory (), conv.get_blocks (),
+                       conv.get_variable_map ());
+      infer.infer ();
+
+      // use #ifdef like the other debug sections of this function
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          jit_block_list& blocks = infer.get_blocks ();
+          jit_block *entry_block = blocks.front ();
+          entry_block->label ();
+          std::cout << "-------------------- Compiling function ";
+          std::cout << "--------------------\n";
+
+          tree_print_code tpc (std::cout);
+          tpc.visit_octave_user_function_header (fcn);
+          tpc.visit_statement_list (*fcn.body ());
+          tpc.visit_octave_user_function_trailer (fcn);
+          blocks.print (std::cout, "octave jit ir");
+        }
+#endif
+
+      jit_factory& factory = conv.get_factory ();
+      llvm::Module *module = tjit.get_module ();
+      jit_convert_llvm to_llvm;
+      jit_function raw_fn = to_llvm.convert_function (module,
+                                                      infer.get_blocks (),
+                                                      factory.constants (),
+                                                      fcn, argument_types);
+
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          std::cout << "-------------------- raw function ";
+          std::cout << "--------------------\n";
+          std::cout << *raw_fn.to_llvm () << std::endl;
+        }
+#endif
+
+      // the wrapper receives one pointer to an array of "any" values and
+      // returns an "any" value
+      std::string wrapper_name = fcn.name () + "_wrapper";
+      jit_type *any_t = jit_typeinfo::get_any ();
+      std::vector<jit_type *> wrapper_args (1, jit_typeinfo::get_any_ptr ());
+      jit_function wrapper (module, jit_convention::internal, wrapper_name,
+                            any_t, wrapper_args);
+      llvm::BasicBlock *wrapper_body = wrapper.new_block ();
+      builder.SetInsertPoint (wrapper_body);
+
+      // load element i of the argument array and cast it between "any" and
+      // the type recorded for argument i
+      llvm::Value *wrapper_arg = wrapper.argument (builder, 0);
+      std::vector<llvm::Value *> raw_args (nargs);
+      for (size_t i = 0; i < nargs; ++i)
+        {
+          llvm::Value *arg;
+          arg = builder.CreateConstInBoundsGEP1_32 (wrapper_arg, i);
+          arg = builder.CreateLoad (arg);
+
+          jit_type *arg_type = argument_types[i];
+          const jit_function& cast = jit_typeinfo::cast (arg_type, any_t);
+          raw_args[i] = cast.call (builder, arg);
+        }
+
+      llvm::Value *result = raw_fn.call (builder, raw_args);
+      if (raw_fn.result ())
+        {
+          jit_type *raw_result_t = raw_fn.result ();
+          const jit_function& cast = jit_typeinfo::cast (any_t, raw_result_t);
+          result = cast.call (builder, result);
+        }
+      else
+        {
+          // no result: return the null value of the "any" type.  A bitcast
+          // from an i32 constant is not a valid way to produce it.
+          result = llvm::Constant::getNullValue (any_t->to_llvm ());
+        }
+
+      wrapper.do_return (builder, result);
+
+      llvm::Function *llvm_function = wrapper.to_llvm ();
+      tjit.optimize (llvm_function);
+
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          std::cout << "-------------------- optimized and wrapped ";
+          std::cout << "--------------------\n";
+          std::cout << *llvm_function << std::endl;
+        }
+#endif
+
+      llvm::ExecutionEngine* engine = tjit.get_engine ();
+      void *void_fn = engine->getPointerToFunction (llvm_function);
+      function = reinterpret_cast<jited_function> (void_fn);
+    }
+  catch (const jit_fail_exception& e)
+    {
+      // function stays 0, which marks this info as a failed compilation
+      argument_types.clear ();
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          if (e.known ())
+            std::cout << "jit fail: " << e.what () << std::endl;
+        }
+#endif
+    }
+}
+
+// Run the compiled function on OV_ARGS.
+//
+// Returns false without touching RETVAL when no compiled function is
+// available.  Otherwise the compiled function is called with an array
+// holding the internal representation of every argument, a non-null result
+// is stored in RETVAL(0), and true is returned.
+bool
+jit_function_info::execute (const octave_value_list& ov_args,
+                            octave_value_list& retval) const
+{
+  if (! function)
+    return false;
+
+  // TODO figure out a way to delete ov_args so we avoid duplicating refcount
+  size_t nargs = ov_args.length ();
+  std::vector<octave_base_value *> args (nargs);
+  for (size_t i = 0; i < nargs; ++i)
+    {
+      octave_base_value *obv = ov_args(i).internal_rep ();
+      obv->grab ();
+      args[i] = obv;
+    }
+
+  // indexing an empty vector is undefined, so a call without arguments
+  // passes a null array pointer instead of &args[0]
+  octave_base_value **argv = args.empty () ? 0 : &args[0];
+  octave_base_value *ret = function (argv);
+  if (ret)
+    retval(0) = octave_value (ret);
+
+  return true;
+}
+
+// Return true if this info can be used for a call with OV_ARGS: the number
+// of arguments and the jit type of each one must equal what was recorded in
+// ARGUMENT_TYPES.  An info without a compiled function matches every call,
+// so tree_jit::do_execute does not replace it.
+bool
+jit_function_info::match (const octave_value_list& ov_args) const
+{
+  if (! function)
+    return true;
+
+  const size_t nargs = ov_args.length ();
+  bool same = (nargs == argument_types.size ());
+
+  // stop at the first argument whose type differs
+  for (size_t i = 0; same && i < nargs; ++i)
+    same = (jit_typeinfo::type_of (ov_args(i)) == argument_types[i]);
+
+  return same;
+}
+
 // -------------------- jit_info --------------------
 jit_info::jit_info (tree_jit& tjit, tree& tee)
   : engine (tjit.get_engine ()), function (0), llvm_function (0)
@@ -1739,8 +2125,9 @@
 
       jit_factory& factory = conv.get_factory ();
       jit_convert_llvm to_llvm;
-      llvm_function = to_llvm.convert (tjit.get_module (), infer.get_blocks (),
-                                       factory.constants ());
+      llvm_function = to_llvm.convert_loop (tjit.get_module (),
+                                            infer.get_blocks (),
+                                            factory.constants ());
       arguments = to_llvm.get_arguments ();
       bounds = conv.get_bounds ();
     }
@@ -2126,4 +2513,13 @@
 
 %!error <undefined near> (test_undef);
 
+%!shared id
+%! id = @(x) x;
+
+%!assert (id (1), 1);
+%!assert (id (1+1i), 1+1i)
+%!assert (id (1, 2), 1)
+%!error <undefined> (id ())
+
+
 */