diff src/pt-jit.h @ 14903:54ea692b8ab5

Reworking JIT implementation src/TEMPLATE-INST/Array-jit.cc: New file. src/TEMPLATE-INST/module.mk: Add Array-jit.cc. src/ov-base.h (octave_base_value::grab, octave_base_value::release): New functions. src/pt-jit.cc: Rewrite. src/pt-jit.h: Rewrite.
author Max Brister <max@2bass.com>
date Sat, 12 May 2012 19:24:32 -0600
parents 516b4a15b775
children 3f81e8b42955
line wrap: on
line diff
--- a/src/pt-jit.h	Wed May 09 12:53:41 2012 -0600
+++ b/src/pt-jit.h	Sat May 12 19:24:32 2012 -0600
@@ -23,31 +23,247 @@
 #if !defined (octave_tree_jit_h)
 #define octave_tree_jit_h 1
 
+#include <list>
 #include <map>
+#include <set>
 #include <stdexcept>
 #include <vector>
 
+#include "Array.h"
 #include "pt-walk.h"
 
-class jit_fail_exception : public std::exception {};
+// -------------------- Current status --------------------
+// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized.
+// However, there is no warning emitted on divide by 0. For example,
+// a = 5;
+// b = a * 5 + a;
+//
+// For other types all binary operations are compiled but not optimized. For
+// example,
+// a = [1 2 3]
+// b = a + a;
+// will compile to do_binary_op (a, a).
+// ---------------------------------------------------------
 
-// LLVM forward declares
+
+// we don't want to include llvm headers here, as they require __STDC_LIMIT_MACROS
+// and __STDC_CONSTANT_MACROS be defined in the entire compilation unit
 namespace llvm
 {
   class Value;
   class Module;
   class FunctionPassManager;
+  class PassManager;
   class ExecutionEngine;
   class Function;
   class BasicBlock;
   class LLVMContext;
+  class Type;
+  class GenericValue;
 }
 
+class octave_base_value;
+class octave_value;
 class tree;
 
+// thrown when we should give up on JIT and interpret
+class jit_fail_exception : public std::exception {};
+
+// Used to keep track of estimated (inferred) types during JIT. This is a
+// hierarchical type system which includes both concrete and abstract types.
+//
+// Currently, we only support any and scalar types. If we can't figure out what
+// type a variable is, we assign it the any type. This allows us to generate
+// code even for the case of poor type inference.
+class
+OCTINTERP_API
+jit_type
+{
+public:
+  jit_type (const std::string& n, bool fi, jit_type *mparent, llvm::Type *lt,
+            int tid) :
+    mname (n), finit (fi), p (mparent), llvm_type (lt), id (tid)
+  {}
+
+  // a user readable type name
+  const std::string& name (void) const { return mname; }
+
+  // do we need to initialize variables of this type, even if they are not
+  // input arguments?
+  bool force_init (void) const { return finit; }
+
+  // a unique id for the type
+  int type_id (void) const { return id; }
+
+  // An abstract base type, may be null
+  jit_type *parent (void) const { return p; }
+
+  // convert to an llvm type
+  llvm::Type *to_llvm (void) const { return llvm_type; }
+
+  // how this type gets passed as a function argument
+  llvm::Type *to_llvm_arg (void) const;
+private:
+  std::string mname;
+  bool finit;
+  jit_type *p;
+  llvm::Type *llvm_type;
+  int id;
+  int depth;
+};
+
+
+// Keeps track of overloads for a builtin function. Used for both type inference
+// and code generation.
+class
+jit_function
+{
+public:
+  struct overload
+  {
+    overload (void) : function (0), can_error (true), result (0) {}
+
+    overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) :
+      function (f), can_error (e), result (r), arguments (1)
+    {
+      arguments[0] = arg0;
+    }
+
+    overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0,
+              jit_type *arg1) : function (f), can_error (e), result (r),
+                                arguments (2)
+    {
+      arguments[0] = arg0;
+      arguments[1] = arg1;
+    }
+
+    llvm::Function *function;
+    bool can_error;
+    jit_type *result;
+    std::vector<jit_type*> arguments;
+  };
+
+  void add_overload (const overload& func)
+  {
+    add_overload (func, func.arguments);
+  }
+
+  void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0,
+                     jit_type *arg1)
+  {
+    overload ol (f, e, r, arg0, arg1);
+    add_overload (ol);
+  }
+
+  void add_overload (const overload& func,
+                     const std::vector<jit_type*>& args);
+
+  const overload& get_overload (const std::vector<jit_type *>& types) const;
+
+  const overload& get_overload (jit_type *arg0) const
+  {
+    std::vector<jit_type *> types (1);
+    types[0] = arg0;
+    return get_overload (types);
+  }
+
+  const overload& get_overload (jit_type *arg0, jit_type *arg1) const
+  {
+    std::vector<jit_type *> types (2);
+    types[0] = arg0;
+    types[1] = arg1;
+    return get_overload (types);
+  }
+
+  jit_type *get_result (const std::vector<jit_type *>& types) const
+  {
+    const overload& temp = get_overload (types);
+    return temp.result;
+  }
+
+  jit_type *get_result (jit_type *arg0, jit_type *arg1) const
+  {
+    const overload& temp = get_overload (arg0, arg1);
+    return temp.result;
+  }
+private:
+  Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const;
+
+  std::vector<Array<overload> > overloads;
+};
+
+// Get information and manipulate jit types.
+class
+OCTINTERP_API
+jit_typeinfo
+{
+public:
+  jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e, llvm::Type *ov);
+
+  jit_type *get_any (void) const { return any; }
+
+  jit_type *get_scalar (void) const { return scalar; }
+
+  jit_type *type_of (const octave_value& ov) const;
+
+  const jit_function& binary_op (int op) const;
+
+  const jit_function::overload& binary_op_overload (int op, jit_type *lhs,
+                                                    jit_type *rhs) const
+    {
+      const jit_function& jf = binary_op (op);
+      return jf.get_overload (lhs, rhs);
+    }
+
+  jit_type *binary_op_result (int op, jit_type *lhs, jit_type *rhs) const
+    {
+      const jit_function::overload& ol = binary_op_overload (op, lhs, rhs);
+      return ol.result;
+    }
+
+  const jit_function::overload& assign_op (jit_type *lhs, jit_type *rhs) const;
+
+  const jit_function::overload& print_value (jit_type *to_print) const;
+
+  // FIXME: generic creation should probably be handled separately
+  void to_generic (jit_type *type, llvm::GenericValue& gv);
+  void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov);
+
+  octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv);
+
+  void reset_generic (size_t nargs);
+private:
+  jit_type *new_type (const std::string& name, bool force_init,
+                      jit_type *parent, llvm::Type *llvm_type);
+
+  void add_print (jit_type *ty, void *call);
+
+  void add_binary_op (jit_type *ty, int op, int llvm_op);
+
+  llvm::Module *module;
+  llvm::ExecutionEngine *engine;
+  int next_id;
+
+  llvm::Type *ov_t;
+
+  std::vector<jit_type*> id_to_type;
+  jit_type *any;
+  jit_type *scalar;
+
+  std::vector<jit_function> binary_ops;
+  jit_function assign_fn;
+  jit_function print_fn;
+
+  size_t scalar_out_idx;
+  std::vector<double> scalar_out;
+
+  size_t ov_out_idx;
+  std::vector<octave_base_value*> ov_out;
+};
+
 class
 OCTINTERP_API
-tree_jit : private tree_walker
+tree_jit
 {
 public:
   tree_jit (void);
@@ -56,146 +272,265 @@
 
   bool execute (tree& tee);
  private:
-  typedef void (*jit_function)(bool*, double*);
+  typedef std::map<std::string, jit_type *> type_map;
+
+  class
+  type_infer : public tree_walker
+  {
+  public:
+    type_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false),
+                                    rvalue_type (0)
+    {}
+
+    const std::set<std::string>& get_argin () const { return argin; }
+
+    const type_map& get_types () const { return types; }
+
+    void visit_anon_fcn_handle (tree_anon_fcn_handle&);
+
+    void visit_argument_list (tree_argument_list&);
+
+    void visit_binary_expression (tree_binary_expression&);
+
+    void visit_break_command (tree_break_command&);
+
+    void visit_colon_expression (tree_colon_expression&);
+
+    void visit_continue_command (tree_continue_command&);
+
+    void visit_global_command (tree_global_command&);
+
+    void visit_persistent_command (tree_persistent_command&);
+
+    void visit_decl_elt (tree_decl_elt&);
+
+    void visit_decl_init_list (tree_decl_init_list&);
+
+    void visit_simple_for_command (tree_simple_for_command&);
+
+    void visit_complex_for_command (tree_complex_for_command&);
+
+    void visit_octave_user_script (octave_user_script&);
+
+    void visit_octave_user_function (octave_user_function&);
+
+    void visit_octave_user_function_header (octave_user_function&);
+
+    void visit_octave_user_function_trailer (octave_user_function&);
+
+    void visit_function_def (tree_function_def&);
+
+    void visit_identifier (tree_identifier&);
+
+    void visit_if_clause (tree_if_clause&);
+
+    void visit_if_command (tree_if_command&);
+
+    void visit_if_command_list (tree_if_command_list&);
+
+    void visit_index_expression (tree_index_expression&);
+
+    void visit_matrix (tree_matrix&);
+
+    void visit_cell (tree_cell&);
+
+    void visit_multi_assignment (tree_multi_assignment&);
+
+    void visit_no_op_command (tree_no_op_command&);
+
+    void visit_constant (tree_constant&);
+
+    void visit_fcn_handle (tree_fcn_handle&);
+
+    void visit_parameter_list (tree_parameter_list&);
+
+    void visit_postfix_expression (tree_postfix_expression&);
+
+    void visit_prefix_expression (tree_prefix_expression&);
+
+    void visit_return_command (tree_return_command&);
+
+    void visit_return_list (tree_return_list&);
+
+    void visit_simple_assignment (tree_simple_assignment&);
+
+    void visit_statement (tree_statement&);
+
+    void visit_statement_list (tree_statement_list&);
+
+    void visit_switch_case (tree_switch_case&);
+
+    void visit_switch_case_list (tree_switch_case_list&);
+
+    void visit_switch_command (tree_switch_command&);
+
+    void visit_try_catch_command (tree_try_catch_command&);
+
+    void visit_unwind_protect_command (tree_unwind_protect_command&);
+
+    void visit_while_command (tree_while_command&);
+
+    void visit_do_until_command (tree_do_until_command&);
+  private:
+    void handle_identifier (const std::string& name, octave_value v);
+
+    jit_typeinfo *tinfo;
+
+    bool is_lvalue;
+    jit_type *rvalue_type;
+
+    type_map types;
+    std::set<std::string> argin;
+
+    std::vector<jit_type *> type_stack;
+  };
+
+  class
+  code_generator : public tree_walker
+  {
+  public:
+    code_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee,
+                    const std::set<std::string>& argin,
+                    const type_map& infered_types);
+
+    llvm::Function *get_function () const { return function; }
+
+    void visit_anon_fcn_handle (tree_anon_fcn_handle&);
+
+    void visit_argument_list (tree_argument_list&);
+
+    void visit_binary_expression (tree_binary_expression&);
+
+    void visit_break_command (tree_break_command&);
+
+    void visit_colon_expression (tree_colon_expression&);
+
+    void visit_continue_command (tree_continue_command&);
+
+    void visit_global_command (tree_global_command&);
+
+    void visit_persistent_command (tree_persistent_command&);
+
+    void visit_decl_elt (tree_decl_elt&);
+
+    void visit_decl_init_list (tree_decl_init_list&);
+
+    void visit_simple_for_command (tree_simple_for_command&);
+
+    void visit_complex_for_command (tree_complex_for_command&);
+
+    void visit_octave_user_script (octave_user_script&);
+
+    void visit_octave_user_function (octave_user_function&);
+
+    void visit_octave_user_function_header (octave_user_function&);
+
+    void visit_octave_user_function_trailer (octave_user_function&);
+
+    void visit_function_def (tree_function_def&);
+
+    void visit_identifier (tree_identifier&);
+
+    void visit_if_clause (tree_if_clause&);
+
+    void visit_if_command (tree_if_command&);
+
+    void visit_if_command_list (tree_if_command_list&);
+
+    void visit_index_expression (tree_index_expression&);
+
+    void visit_matrix (tree_matrix&);
+
+    void visit_cell (tree_cell&);
+
+    void visit_multi_assignment (tree_multi_assignment&);
+
+    void visit_no_op_command (tree_no_op_command&);
+
+    void visit_constant (tree_constant&);
+
+    void visit_fcn_handle (tree_fcn_handle&);
+
+    void visit_parameter_list (tree_parameter_list&);
+
+    void visit_postfix_expression (tree_postfix_expression&);
+
+    void visit_prefix_expression (tree_prefix_expression&);
+
+    void visit_return_command (tree_return_command&);
+
+    void visit_return_list (tree_return_list&);
+
+    void visit_simple_assignment (tree_simple_assignment&);
+
+    void visit_statement (tree_statement&);
+
+    void visit_statement_list (tree_statement_list&);
+
+    void visit_switch_case (tree_switch_case&);
+
+    void visit_switch_case_list (tree_switch_case_list&);
+
+    void visit_switch_command (tree_switch_command&);
+
+    void visit_try_catch_command (tree_try_catch_command&);
+
+    void visit_unwind_protect_command (tree_unwind_protect_command&);
+
+    void visit_while_command (tree_while_command&);
+
+    void visit_do_until_command (tree_do_until_command&);
+  private:
+    typedef std::pair<jit_type *, llvm::Value *> value;
+
+    void emit_print (const std::string& name, const value& v);
+
+    void push_value (jit_type *type, llvm::Value *v)
+    {
+      value_stack.push_back (value (type, v));
+    }
+
+    jit_typeinfo *tinfo;
+    llvm::Function *function;
+
+    bool is_lvalue;
+    std::map<std::string, value> variables;
+    std::vector<value> value_stack;
+  };
 
   class function_info
   {
   public:
-    function_info (void);
-    function_info (jit_function fn, const std::vector<std::string>& args,
-                   const std::vector<bool>& args_used);
+    function_info (tree_jit& tjit, tree& tee);
 
-    bool execute ();
+    bool execute () const;
+
+    bool match () const;
   private:
-    jit_function function;
-    std::vector<std::string> arguments;
-    
-    // is the argument used? or is it just declared?
-    std::vector<bool> argument_used;
-  };
-
-  struct variable_info
-  {
-    llvm::Value *defined;
-    llvm::Value *value;
-    bool use;
+    jit_typeinfo *tinfo;
+    llvm::ExecutionEngine *engine;
+    std::set<std::string> argin;
+    type_map types;
+    llvm::Function *function;
   };
 
-  function_info *compile (tree& tee);
-
-  variable_info find (const std::string &name, bool use);
-
-  void do_assign (variable_info vinfo, llvm::Value *value);
-
-  void emit_print (const std::string& vname, llvm::Value *value);
-
-  // tree_walker
-  void visit_anon_fcn_handle (tree_anon_fcn_handle&);
-
-  void visit_argument_list (tree_argument_list&);
-
-  void visit_binary_expression (tree_binary_expression&);
-
-  void visit_break_command (tree_break_command&);
-
-  void visit_colon_expression (tree_colon_expression&);
-
-  void visit_continue_command (tree_continue_command&);
-
-  void visit_global_command (tree_global_command&);
-
-  void visit_persistent_command (tree_persistent_command&);
-
-  void visit_decl_elt (tree_decl_elt&);
-
-  void visit_decl_init_list (tree_decl_init_list&);
-
-  void visit_simple_for_command (tree_simple_for_command&);
-
-  void visit_complex_for_command (tree_complex_for_command&);
-
-  void visit_octave_user_script (octave_user_script&);
-
-  void visit_octave_user_function (octave_user_function&);
-
-  void visit_octave_user_function_header (octave_user_function&);
-
-  void visit_octave_user_function_trailer (octave_user_function&);
-
-  void visit_function_def (tree_function_def&);
-
-  void visit_identifier (tree_identifier&);
-
-  void visit_if_clause (tree_if_clause&);
-
-  void visit_if_command (tree_if_command&);
-
-  void visit_if_command_list (tree_if_command_list&);
-
-  void visit_index_expression (tree_index_expression&);
+  typedef std::list<function_info *> function_list;
+  typedef std::map<tree *, function_list> compiled_map;
 
-  void visit_matrix (tree_matrix&);
-
-  void visit_cell (tree_cell&);
-
-  void visit_multi_assignment (tree_multi_assignment&);
-
-  void visit_no_op_command (tree_no_op_command&);
-
-  void visit_constant (tree_constant&);
-
-  void visit_fcn_handle (tree_fcn_handle&);
-
-  void visit_parameter_list (tree_parameter_list&);
-
-  void visit_postfix_expression (tree_postfix_expression&);
-
-  void visit_prefix_expression (tree_prefix_expression&);
-
-  void visit_return_command (tree_return_command&);
-
-  void visit_return_list (tree_return_list&);
-
-  void visit_simple_assignment (tree_simple_assignment&);
-
-  void visit_statement (tree_statement&);
-
-  void visit_statement_list (tree_statement_list&);
-
-  void visit_switch_case (tree_switch_case&);
-
-  void visit_switch_case_list (tree_switch_case_list&);
-
-  void visit_switch_command (tree_switch_command&);
-
-  void visit_try_catch_command (tree_try_catch_command&);
-
-  void do_unwind_protect_cleanup_code (tree_statement_list *list);
-
-  void visit_unwind_protect_command (tree_unwind_protect_command&);
-
-  void visit_while_command (tree_while_command&);
-
-  void visit_do_until_command (tree_do_until_command&);
-
-  void fail (void);
-
-  typedef std::map<std::string, variable_info> var_map;
-  typedef var_map::iterator var_map_iterator;
-  typedef std::map<tree*, function_info*> finfo_map;
-  typedef finfo_map::iterator finfo_map_iterator;
-
-  std::vector<llvm::Value*> value_stack;
-  var_map variables;
-  finfo_map compiled_functions;
+  static void fail (void)
+  {
+    throw jit_fail_exception ();
+  }
 
   llvm::LLVMContext &context;
   llvm::Module *module;
+  llvm::PassManager *module_pass_manager;
   llvm::FunctionPassManager *pass_manager;
   llvm::ExecutionEngine *engine;
-  llvm::BasicBlock *entry_block;
 
-  llvm::Function *print_double;
+  jit_typeinfo *tinfo;
+
+  compiled_map compiled;
 };
 
 #endif