view src/pt-jit.h @ 14911:1e2196d0bea4

doc: Removed old FIXMEs
author Max Brister <max@2bass.com>
date Fri, 18 May 2012 08:11:00 -0600
parents a8f1e08de8fc
children c7071907a641
line wrap: on
line source

/*

Copyright (C) 2012 Max Brister <max@2bass.com>

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/

#if !defined (octave_tree_jit_h)
#define octave_tree_jit_h 1

#include <list>
#include <map>
#include <set>
#include <stdexcept>
#include <vector>

#include "Array.h"
#include "Range.h"
#include "pt-walk.h"

// -------------------- Current status --------------------
// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized.
// However, there is no warning emitted on divide by 0. For example,
// a = 5;
// b = a * 5 + a;
//
// For other types all binary operations are compiled but not optimized. For
// example,
// a = [1 2 3]
// b = a + a;
// will compile to do_binary_op (a, a).
//
// for loops with ranges compile. For example,
// for i=1:1000
//   result = i + 1;
// endfor
// Will compile. Nested for loops with constant bounds are also supported.
//
// TODO:
// 1. Cleanup
// 2. Support if statements
// 3. Support iteration over matricies
// 4. Check error state
// 5. ...
// ---------------------------------------------------------


// we don't want to include llvm headers here, as they require __STDC_LIMIT_MACROS
// and __STDC_CONSTANT_MACROS be defined in the entire compilation unit
namespace llvm
{
  class Value;
  class Module;
  class FunctionPassManager;
  class PassManager;
  class ExecutionEngine;
  class Function;
  class BasicBlock;
  class LLVMContext;
  class Type;
  class GenericValue;
}

class octave_base_value;
class octave_value;
class tree;

// jit_range is compatable with the llvm range structure
struct
jit_range
{
  jit_range (void) {}

  jit_range (const Range& from) : base (from.base ()), limit (from.limit ()),
    inc (from.inc ()), nelem (from.nelem ())
    {}

  operator Range () const
  {
    return Range (base, limit, inc);
  }

  double base;
  double limit;
  double inc;
  octave_idx_type nelem;
};

// Used to keep track of estimated (infered) types during JIT. This is a
// hierarchical type system which includes both concrete and abstract types.
//
// Current, we only support any and scalar types. If we can't figure out what
// type a variable is, we assign it the any type. This allows us to generate
// code even for the case of poor type inference.
class
jit_type
{
public:
  jit_type (const std::string& n, bool fi, jit_type *mparent, llvm::Type *lt,
            int tid) :
    mname (n), finit (fi), p (mparent), llvm_type (lt), id (tid)
  {}

  // a user readable type name
  const std::string& name (void) const { return mname; }

  // do we need to initialize variables of this type, even if they are not
  // input arguments?
  bool force_init (void) const { return finit; }

  // a unique id for the type
  int type_id (void) const { return id; }

  // An abstract base type, may be null
  jit_type *parent (void) const { return p; }

  // convert to an llvm type
  llvm::Type *to_llvm (void) const { return llvm_type; }

  // how this type gets passed as a function argument
  llvm::Type *to_llvm_arg (void) const;
private:
  std::string mname;
  bool finit;
  jit_type *p;
  llvm::Type *llvm_type;
  int id;
};

// Keeps track of overloads for a builtin function. Used for both type inference
// and code generation.
class
jit_function
{
public:
  struct overload
  {
    overload (void) : function (0), can_error (true), result (0) {}

    overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) :
      function (f), can_error (e), result (r), arguments (1)
    {
      arguments[0] = arg0;
    }

    overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0,
              jit_type *arg1) : function (f), can_error (e), result (r),
                                arguments (2)
    {
      arguments[0] = arg0;
      arguments[1] = arg1;
    }

    llvm::Function *function;
    bool can_error;
    jit_type *result;
    std::vector<jit_type*> arguments;
  };

  void add_overload (const overload& func)
  {
    add_overload (func, func.arguments);
  }

  void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0)
  {
    overload ol (f, e, r, arg0);
    add_overload (ol);
  }

  void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0,
                     jit_type *arg1)
  {
    overload ol (f, e, r, arg0, arg1);
    add_overload (ol);
  }

  void add_overload (const overload& func,
                     const std::vector<jit_type*>& args);

  const overload& get_overload (const std::vector<jit_type *>& types) const;

  const overload& get_overload (jit_type *arg0) const
  {
    std::vector<jit_type *> types (1);
    types[0] = arg0;
    return get_overload (types);
  }

  const overload& get_overload (jit_type *arg0, jit_type *arg1) const
  {
    std::vector<jit_type *> types (2);
    types[0] = arg0;
    types[1] = arg1;
    return get_overload (types);
  }

  jit_type *get_result (const std::vector<jit_type *>& types) const
  {
    const overload& temp = get_overload (types);
    return temp.result;
  }

  jit_type *get_result (jit_type *arg0, jit_type *arg1) const
  {
    const overload& temp = get_overload (arg0, arg1);
    return temp.result;
  }
private:
  Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const;

  std::vector<Array<overload> > overloads;
};

// Get information and manipulate jit types.
class
jit_typeinfo
{
public:
  jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e);

  jit_type *get_any (void) const { return any; }

  jit_type *get_scalar (void) const { return scalar; }

  llvm::Type *get_scalar_llvm (void) const { return scalar->to_llvm (); }

  jit_type *get_range (void) const { return range; }

  llvm::Type *get_range_llvm (void) const { return range->to_llvm (); }

  jit_type *get_bool (void) const { return boolean; }

  jit_type *get_index (void) const { return index; }

  llvm::Type *get_index_llvm (void) const { return index->to_llvm (); }

  jit_type *type_of (const octave_value& ov) const;

  const jit_function& binary_op (int op) const;

  const jit_function::overload& binary_op_overload (int op, jit_type *lhs,
                                                    jit_type *rhs) const
    {
      const jit_function& jf = binary_op (op);
      return jf.get_overload (lhs, rhs);
    }

  jit_type *binary_op_result (int op, jit_type *lhs, jit_type *rhs) const
    {
      const jit_function::overload& ol = binary_op_overload (op, lhs, rhs);
      return ol.result;
    }

  const jit_function::overload& assign_op (jit_type *lhs, jit_type *rhs) const;

  const jit_function::overload& print_value (jit_type *to_print) const;

  const jit_function::overload& get_simple_for_check (jit_type *bounds) const
  {
    return simple_for_check.get_overload (bounds, index);
  }

  const jit_function::overload& get_simple_for_index (jit_type *bounds) const
  {
    return simple_for_index.get_overload (bounds, index);
  }

  jit_type *get_simple_for_index_result (jit_type *bounds) const
  {
    const jit_function::overload& ol = get_simple_for_index (bounds);
    return ol.result;
  }

  void to_generic (jit_type *type, llvm::GenericValue& gv);

  void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov);

  octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv);

  void reset_generic (void);
private:
  typedef std::map<std::string, jit_type *> type_map;

  jit_type *new_type (const std::string& name, bool force_init,
                      jit_type *parent, llvm::Type *llvm_type);

  void add_print (jit_type *ty, void *call);

  void add_binary_op (jit_type *ty, int op, int llvm_op);

  llvm::Module *module;
  llvm::ExecutionEngine *engine;
  int next_id;

  llvm::Type *ov_t;

  std::vector<jit_type*> id_to_type;
  jit_type *any;
  jit_type *scalar;
  jit_type *range;
  jit_type *boolean;
  jit_type *index;

  std::vector<jit_function> binary_ops;
  jit_function assign_fn;
  jit_function print_fn;
  jit_function simple_for_check;
  jit_function simple_for_incr;
  jit_function simple_for_index;

  std::list<double> scalar_out;
  std::list<octave_base_value *> ov_out;
  std::list<jit_range> range_out;
};

class
jit_infer : public tree_walker
{
  typedef std::map<std::string, jit_type *> type_map;
public:
  jit_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false),
                                  rvalue_type (0)
  {}

  const std::set<std::string>& get_argin () const { return argin; }

  const type_map& get_types () const { return types; }

  void infer (tree_simple_for_command& cmd, jit_type *bounds);

  void visit_anon_fcn_handle (tree_anon_fcn_handle&);

  void visit_argument_list (tree_argument_list&);

  void visit_binary_expression (tree_binary_expression&);

  void visit_break_command (tree_break_command&);

  void visit_colon_expression (tree_colon_expression&);

  void visit_continue_command (tree_continue_command&);

  void visit_global_command (tree_global_command&);

  void visit_persistent_command (tree_persistent_command&);

  void visit_decl_elt (tree_decl_elt&);

  void visit_decl_init_list (tree_decl_init_list&);

  void visit_simple_for_command (tree_simple_for_command&);

  void visit_complex_for_command (tree_complex_for_command&);

  void visit_octave_user_script (octave_user_script&);

  void visit_octave_user_function (octave_user_function&);

  void visit_octave_user_function_header (octave_user_function&);

  void visit_octave_user_function_trailer (octave_user_function&);

  void visit_function_def (tree_function_def&);

  void visit_identifier (tree_identifier&);

  void visit_if_clause (tree_if_clause&);

  void visit_if_command (tree_if_command&);

  void visit_if_command_list (tree_if_command_list&);

  void visit_index_expression (tree_index_expression&);

  void visit_matrix (tree_matrix&);

  void visit_cell (tree_cell&);

  void visit_multi_assignment (tree_multi_assignment&);

  void visit_no_op_command (tree_no_op_command&);

  void visit_constant (tree_constant&);

  void visit_fcn_handle (tree_fcn_handle&);

  void visit_parameter_list (tree_parameter_list&);

  void visit_postfix_expression (tree_postfix_expression&);

  void visit_prefix_expression (tree_prefix_expression&);

  void visit_return_command (tree_return_command&);

  void visit_return_list (tree_return_list&);

  void visit_simple_assignment (tree_simple_assignment&);

  void visit_statement (tree_statement&);

  void visit_statement_list (tree_statement_list&);

  void visit_switch_case (tree_switch_case&);

  void visit_switch_case_list (tree_switch_case_list&);

  void visit_switch_command (tree_switch_command&);

  void visit_try_catch_command (tree_try_catch_command&);

  void visit_unwind_protect_command (tree_unwind_protect_command&);

  void visit_while_command (tree_while_command&);

  void visit_do_until_command (tree_do_until_command&);
private:
  void infer_simple_for (tree_simple_for_command& cmd,
                         jit_type *bounds);

  void handle_identifier (const std::string& name, octave_value v);

  jit_typeinfo *tinfo;

  bool is_lvalue;
  jit_type *rvalue_type;

  type_map types;
  std::set<std::string> argin;

  std::vector<jit_type *> type_stack;
};

class
jit_generator : public tree_walker
{
  typedef std::map<std::string, jit_type *> type_map;
public:
  jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee,
                 const std::set<std::string>& argin,
                 const type_map& infered_types, bool have_bounds = true);

  llvm::Function *get_function () const { return function; }

  void visit_anon_fcn_handle (tree_anon_fcn_handle&);

  void visit_argument_list (tree_argument_list&);

  void visit_binary_expression (tree_binary_expression&);

  void visit_break_command (tree_break_command&);

  void visit_colon_expression (tree_colon_expression&);

  void visit_continue_command (tree_continue_command&);

  void visit_global_command (tree_global_command&);

  void visit_persistent_command (tree_persistent_command&);

  void visit_decl_elt (tree_decl_elt&);

  void visit_decl_init_list (tree_decl_init_list&);

  void visit_simple_for_command (tree_simple_for_command&);

  void visit_complex_for_command (tree_complex_for_command&);

  void visit_octave_user_script (octave_user_script&);

  void visit_octave_user_function (octave_user_function&);

  void visit_octave_user_function_header (octave_user_function&);

  void visit_octave_user_function_trailer (octave_user_function&);

  void visit_function_def (tree_function_def&);

  void visit_identifier (tree_identifier&);

  void visit_if_clause (tree_if_clause&);

  void visit_if_command (tree_if_command&);

  void visit_if_command_list (tree_if_command_list&);

  void visit_index_expression (tree_index_expression&);

  void visit_matrix (tree_matrix&);

  void visit_cell (tree_cell&);

  void visit_multi_assignment (tree_multi_assignment&);

  void visit_no_op_command (tree_no_op_command&);

  void visit_constant (tree_constant&);

  void visit_fcn_handle (tree_fcn_handle&);

  void visit_parameter_list (tree_parameter_list&);

  void visit_postfix_expression (tree_postfix_expression&);

  void visit_prefix_expression (tree_prefix_expression&);

  void visit_return_command (tree_return_command&);

  void visit_return_list (tree_return_list&);

  void visit_simple_assignment (tree_simple_assignment&);

  void visit_statement (tree_statement&);

  void visit_statement_list (tree_statement_list&);

  void visit_switch_case (tree_switch_case&);

  void visit_switch_case_list (tree_switch_case_list&);

  void visit_switch_command (tree_switch_command&);

  void visit_try_catch_command (tree_try_catch_command&);

  void visit_unwind_protect_command (tree_unwind_protect_command&);

  void visit_while_command (tree_while_command&);

  void visit_do_until_command (tree_do_until_command&);
private:
  typedef std::pair<jit_type *, llvm::Value *> value;

  void emit_simple_for (tree_simple_for_command& cmd, value over,
                        bool atleast_once);

  void emit_print (const std::string& name, const value& v);

  void push_value (jit_type *type, llvm::Value *v)
  {
    value_stack.push_back (value (type, v));
  }

  jit_typeinfo *tinfo;
  llvm::Function *function;

  bool is_lvalue;
  std::map<std::string, value> variables;
  std::vector<value> value_stack;
};

class
tree_jit
{
public:
  tree_jit (void);

  ~tree_jit (void);

  bool execute (tree_simple_for_command& cmd, const octave_value& bounds);

  jit_typeinfo *get_typeinfo (void) const { return tinfo; }

  llvm::ExecutionEngine *get_engine (void) const { return engine; }

  llvm::Module *get_module (void) const { return module; }

  void optimize (llvm::Function *fn);
 private:
  bool initialize (void);

  llvm::LLVMContext &context;
  llvm::Module *module;
  llvm::PassManager *module_pass_manager;
  llvm::FunctionPassManager *pass_manager;
  llvm::ExecutionEngine *engine;

  jit_typeinfo *tinfo;
};

class
jit_info
{
public:
  jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds);

  bool execute (const octave_value& bounds) const;

  bool match (void) const;
private:
  typedef std::map<std::string, jit_type *> type_map;
  
  jit_typeinfo *tinfo;
  llvm::ExecutionEngine *engine;
  std::set<std::string> argin;
  type_map types;
  llvm::Function *function;
};

#endif