diff libinterp/interp-core/pt-jit.h @ 15337:3f43e9d6d86e

JIT compile anonymous functions * jit-ir.h (jit_block::front, jit_block::back): New function. (jit_call::jit_call): New overloads. (jit_return): New class. * jit-typeinfo.cc (octave_jit_create_undef): New function. (jit_operation::to_idx): Correctly handle empty type vector. (jit_typeinfo::jit_typeinfo): Add destroy_fn and initialize create_undef. * jit-typeinfo.h (jit_typeinfo::get_any_ptr, jit_typeinfo::destroy, jit_typeinfo::create_undef): New function. * pt-jit.cc (jit_convert::jit_convert): Add overload and refactor. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute, jit_function_info::jit_function_info, jit_function_info::execute, jit_function_info::match): New function. (jit_convert::get_variable): Support function variable lookup. (jit_convert_llvm::convert): Handle loop/function agnostic stuff. (jit_convert_llvm::visit): Handle function creation as well. (tree_jit::execute): Move implementation to tree_jit::do_execute. (jit_info::compile): Call convert_loop instead of convert. * pt-jit.h (jit_convert::jit_convert): New overload. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute): New function. (jit_convert::create_variable, jit_convert_llvm::initialize): Update signature. (tree_jit::execute): Made static. (tree_jit::tree_jit): Made private. (jit_function_info): New class. * ov-usr-fcn.cc (octave_user_function::~octave_user_function): Delete jit_info. (octave_user_function::octave_user_function): Maybe JIT and use is_special_expr and special_expr. (octave_user_function::special_expr): New function. * ov-usr-fcn.h (octave_user_function::is_special_expr, octave_user_function::special_expr, octave_user_function::get_info, octave_user_function::stash_info): New function. * pt-decl.h (tree_decl_elt::name): New function. * pt-eval.cc (tree_evaluator::visit_simple_for_command, tree_evaluator::visit_while_command): Use static tree_jit methods.
author Max Brister <max@2bass.com>
date Sun, 09 Sep 2012 00:29:00 -0600
parents 8125773322d4
children 8355fddce815
line wrap: on
line diff
--- a/libinterp/interp-core/pt-jit.h	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/pt-jit.h	Sun Sep 09 00:29:00 2012 -0600
@@ -26,8 +26,10 @@
 #ifdef HAVE_LLVM
 
 #include "jit-ir.h"
+#include "pt-walk.h"
+#include "symtab.h"
 
-#include "pt-walk.h"
+class octave_value_list;
 
 // Convert from the parse tree (AST) to the low level Octave IR.
 class
@@ -40,6 +42,8 @@
 
   jit_convert (tree &tee, jit_type *for_bounds = 0);
 
+  jit_convert (octave_user_function& fcn, const std::vector<jit_type *>& args);
+
 #define DECL_ARG(n) const ARG ## n& arg ## n
 #define JIT_CREATE_CHECKED(N)                                           \
   template <OCT_MAKE_DECL_LIST (typename, ARG, N)>                      \
@@ -156,6 +160,11 @@
   std::vector<std::pair<std::string, bool> > arguments;
   type_bound_vector bounds;
 
+  bool converting_function;
+
+  // the scope of the function we are converting, or the current scope
+  symbol_table::scope_id scope;
+
   jit_factory factory;
 
   // used instead of return values from visit_* functions
@@ -179,6 +188,8 @@
 
   variable_map vmap;
 
+  void initialize (symbol_table::scope_id s);
+
   jit_call *create_checked_impl (jit_call *ret);
 
   // get an existing vairable. If the variable does not exist, it will not be
@@ -191,7 +202,8 @@
 
   // create a variable of the given name and given type. Will also insert an
   // extract statement
-  jit_variable *create_variable (const std::string& vname, jit_type *type);
+  jit_variable *create_variable (const std::string& vname, jit_type *type,
+                                 bool isarg = true);
 
   // The name of the next for loop iterator. If inc is false, then the iterator
   // counter will not be incremented.
@@ -233,10 +245,17 @@
 jit_convert_llvm : public jit_ir_walker
 {
 public:
-  llvm::Function *convert (llvm::Module *module,
-                           const jit_block_list& blocks,
-                           const std::list<jit_value *>& constants);
+  llvm::Function *convert_loop (llvm::Module *module,
+                                const jit_block_list& blocks,
+                                const std::list<jit_value *>& constants);
 
+  jit_function convert_function (llvm::Module *module,
+                                 const jit_block_list& blocks,
+                                 const std::list<jit_value *>& constants,
+                                 octave_user_function& fcn,
+                                 const std::vector<jit_type *>& args);
+
+  // arguments to the llvm::Function for loops
   const std::vector<std::pair<std::string, bool> >& get_arguments(void) const
   { return argument_vec; }
 
@@ -247,13 +266,22 @@
 
 #undef JIT_METH
 private:
+  // name -> argument index (used for compiling functions)
+  std::map<std::string, int> argument_index;
+
   std::vector<std::pair<std::string, bool> > argument_vec;
 
-  // name -> llvm argument
+  // name -> llvm argument (used for compiling loops)
   std::map<std::string, llvm::Value *> arguments;
+
+  bool converting_function;
+
   llvm::Function *function;
   llvm::BasicBlock *prelude;
 
+  void convert (const jit_block_list& blocks,
+                const std::list<jit_value *>& constants);
+
   void finish_phi (jit_phi *phi);
 
   void visit (jit_value *jvalue)
@@ -319,13 +347,15 @@
 tree_jit
 {
 public:
-  tree_jit (void);
-
   ~tree_jit (void);
 
-  bool execute (tree_simple_for_command& cmd, const octave_value& bounds);
+  static bool execute (tree_simple_for_command& cmd,
+                       const octave_value& bounds);
 
-  bool execute (tree_while_command& cmd);
+  static bool execute (tree_while_command& cmd);
+
+  static bool execute (octave_user_function& fcn, const octave_value_list& args,
+                       octave_value_list& retval);
 
   llvm::ExecutionEngine *get_engine (void) const { return engine; }
 
@@ -333,8 +363,19 @@
 
   void optimize (llvm::Function *fn);
  private:
+  tree_jit (void);
+
+  static tree_jit& instance (void);
+
   bool initialize (void);
 
+  bool do_execute (tree_simple_for_command& cmd, const octave_value& bounds);
+
+  bool do_execute (tree_while_command& cmd);
+
+  bool do_execute (octave_user_function& fcn, const octave_value_list& args,
+                   octave_value_list& retval);
+
   size_t trip_count (const octave_value& bounds) const;
 
   llvm::Module *module;
@@ -344,6 +385,24 @@
 };
 
 class
+jit_function_info
+{
+public:
+  jit_function_info (tree_jit& tjit, octave_user_function& fcn,
+                     const octave_value_list& ov_args);
+
+  bool execute (const octave_value_list& ov_args,
+                octave_value_list& retval) const;
+
+  bool match (const octave_value_list& ov_args) const;
+private:
+  typedef octave_base_value *(*jited_function)(octave_base_value**);
+
+  std::vector<jit_type *> argument_types;
+  jited_function function;
+};
+
+class
 jit_info
 {
 public: