view src/pt-jit.cc @ 14911:1e2196d0bea4

doc: Removed old FIXMEs
author Max Brister <max@2bass.com>
date Fri, 18 May 2012 08:11:00 -0600
parents a8f1e08de8fc
children c7071907a641
line wrap: on
line source

/*

Copyright (C) 2012 Max Brister <max@2bass.com>

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/

#define __STDC_LIMIT_MACROS
#define __STDC_CONSTANT_MACROS

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "pt-jit.h"

#include <typeinfo>

#include <llvm/LLVMContext.h>
#include <llvm/Module.h>
#include <llvm/Function.h>
#include <llvm/BasicBlock.h>
#include <llvm/Support/IRBuilder.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JIT.h>
#include <llvm/PassManager.h>
#include <llvm/Analysis/Verifier.h>
#include <llvm/Analysis/CallGraph.h>
#include <llvm/Analysis/Passes.h>
#include <llvm/Target/TargetData.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/ExecutionEngine/GenericValue.h>

#include "octave.h"
#include "ov-fcn-handle.h"
#include "ov-usr-fcn.h"
#include "pt-all.h"

// FIXME: Remove eventually
// For now we leave this in so people tell when JIT actually happens
static const bool debug_print = false;

static llvm::IRBuilder<> builder (llvm::getGlobalContext ());

// thrown when we should give up on JIT and interpret
class jit_fail_exception : public std::exception {};

static void
fail (void)
{
  throw jit_fail_exception ();
}

// function that jit code calls
extern "C" void
octave_jit_print_any (const char *name, octave_base_value *obv)
{
  obv->print_with_name (octave_stdout, name, true);
}

extern "C" void
octave_jit_print_double (const char *name, double value)
{
  // FIXME: We should avoid allocating a new octave_scalar each time
  octave_value ov (value);
  ov.print_with_name (octave_stdout, name);
}

extern "C" octave_base_value*
octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs,
                           octave_base_value *rhs)
{
  octave_value olhs (lhs, true);
  octave_value orhs (rhs, true);
  octave_value result = do_binary_op (op, olhs, orhs);
  octave_base_value *rep = result.internal_rep ();
  rep->grab ();
  return rep;
}

extern "C" void
octave_jit_assign_any_any_help (octave_base_value *lhs, octave_base_value *rhs)
{
  if (lhs != rhs)
    {
      rhs->grab ();
      lhs->release ();
    }
}

// -------------------- jit_type --------------------
llvm::Type *
jit_type::to_llvm_arg (void) const
{
  return llvm_type ? llvm_type->getPointerTo () : 0;
}

// -------------------- jit_function --------------------
void
jit_function::add_overload (const overload& func,
                            const std::vector<jit_type*>& args)
{
  if (args.size () >= overloads.size ())
    overloads.resize (args.size () + 1);

  Array<overload>& over = overloads[args.size ()];
  dim_vector dv (over.dims ());
  Array<octave_idx_type> idx = to_idx (args);
  bool must_resize = false;

  if (dv.length () != idx.numel ())
    {
      dv.resize (idx.numel ());
      must_resize = true;
    }

  for (octave_idx_type i = 0; i < dv.length (); ++i)
    if (dv(i) <= idx(i))
      {
        must_resize = true;
        dv(i) = idx(i) + 1;
      }

  if (must_resize)
    over.resize (dv);

  over(idx) = func;
}

const jit_function::overload&
jit_function::get_overload (const std::vector<jit_type*>& types) const
{
  // FIXME: We should search for the next best overload on failure
  static overload null_overload;
  if (types.size () >= overloads.size ())
    return null_overload;

  const Array<overload>& over = overloads[types.size ()];
  dim_vector dv (over.dims ());
  Array<octave_idx_type> idx = to_idx (types);
  for (octave_idx_type i = 0; i < dv.length (); ++i)
    if (idx(i) >= dv(i))
      return null_overload;

  return over(idx);
}

Array<octave_idx_type>
jit_function::to_idx (const std::vector<jit_type*>& types) const
{
  octave_idx_type numel = types.size ();
  if (numel == 1)
    numel = 2;

  Array<octave_idx_type> idx (dim_vector (1, numel));
  for (octave_idx_type i = 0; i < static_cast<octave_idx_type> (types.size ());
       ++i)
    idx(i) = types[i]->type_id ();

  if (types.size () == 1)
    {
      idx(1) = idx(0);
      idx(0) = 0;
    }

  return idx;
}

// -------------------- jit_typeinfo --------------------
jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e)
  : module (m), engine (e), next_id (0)
{
  // FIXME: We should be registering types like in octave_value_typeinfo
  llvm::LLVMContext &ctx = m->getContext ();

  ov_t = llvm::StructType::create (ctx, "octave_base_value");
  ov_t = ov_t->getPointerTo ();

  llvm::Type *dbl = llvm::Type::getDoubleTy (ctx);
  llvm::Type *bool_t = llvm::Type::getInt1Ty (ctx);
  llvm::Type *index_t = 0;
  switch (sizeof(octave_idx_type))
    {
    case 4:
      index_t = llvm::Type::getInt32Ty (ctx);
      break;
    case 8:
      index_t = llvm::Type::getInt64Ty (ctx);
      break;
    default:
      assert (false && "Unrecognized index type size");
    }

  llvm::StructType *range_t = llvm::StructType::create (ctx, "range");
  std::vector<llvm::Type *> range_contents (4, dbl);
  range_contents[3] = index_t;
  range_t->setBody (range_contents);

  // create types
  any = new_type ("any", true, 0, ov_t);
  scalar = new_type ("scalar", false, any, dbl);
  range = new_type ("range", false, any, range_t);
  boolean = new_type ("bool", false, any, bool_t);
  index = new_type ("index", false, any, index_t);

  // any with anything is an any op
  llvm::Type *binary_op_type
    = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op));
  std::vector<llvm::Type*> args (3);
  args[0] = binary_op_type;
  args[1] = args[2] = any->to_llvm ();
  llvm::FunctionType *any_binary_t = llvm::FunctionType::get (ov_t, args, false);
  llvm::Function *any_binary = llvm::Function::Create (any_binary_t,
                                                       llvm::Function::ExternalLinkage,
                                                       "octave_jit_binary_any_any",
                                                       module);
  engine->addGlobalMapping (any_binary,
                            reinterpret_cast<void*>(&octave_jit_binary_any_any));

  args.resize (2);
  args[0] = any->to_llvm ();
  args[1] = any->to_llvm ();

  binary_ops.resize (octave_value::num_binary_ops);
  for (int op = 0; op < octave_value::num_binary_ops; ++op)
    {
      llvm::FunctionType *ftype = llvm::FunctionType::get (ov_t, args, false);

      llvm::Twine fn_name ("octave_jit_binary_any_any_");
      fn_name = fn_name + llvm::Twine (op);
      llvm::Function *fn = llvm::Function::Create (ftype,
                                                   llvm::Function::ExternalLinkage,
                                                   fn_name, module);
      llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn);
      builder.SetInsertPoint (block);
      llvm::APInt op_int(sizeof (octave_value::binary_op), op,
                         std::numeric_limits<octave_value::binary_op>::is_signed);
      llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int);
      llvm::Value *ret = builder.CreateCall3 (any_binary,
                                                 op_as_llvm,
                                                 fn->arg_begin (),
                                                 ++fn->arg_begin ());
      builder.CreateRet (ret);

      jit_function::overload overload (fn, true, any, any, any);
      for (octave_idx_type i = 0; i < next_id; ++i)
        binary_ops[op].add_overload (overload);
    }

  // assign any = any
  llvm::Type *void_t = llvm::Type::getVoidTy (ctx);
  args.resize (2);
  args[0] = any->to_llvm ();
  args[1] = any->to_llvm ();
  llvm::FunctionType *ft = llvm::FunctionType::get (void_t, args, false);
  llvm::Function *fn_help = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                                                    "octave_jit_assign_any_any_help",
                                                    module);
  engine->addGlobalMapping (fn_help,
                            reinterpret_cast<void*>(&octave_jit_assign_any_any_help));

  args.resize (2);
  args[0] = any->to_llvm_arg ();
  args[1] = any->to_llvm ();
  ft = llvm::FunctionType::get (void_t, args, false);
  llvm::Function *fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                                               "octave_jit_assign_any_any",
                                               module);
  fn->addFnAttr (llvm::Attribute::AlwaysInline);
  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn);
  builder.SetInsertPoint (body);
  llvm::Value *value = builder.CreateLoad (fn->arg_begin ());
  builder.CreateCall2 (fn_help, value, ++fn->arg_begin ());
  builder.CreateStore (++fn->arg_begin (), fn->arg_begin ());
  builder.CreateRetVoid ();
  llvm::verifyFunction (*fn);
  assign_fn.add_overload (fn, false, 0, any, any);

  // assign scalar = scalar
  args.resize (2);
  args[0] = scalar->to_llvm_arg ();
  args[1] = scalar->to_llvm ();
  ft = llvm::FunctionType::get (void_t, args, false);
  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                               "octave_jit_assign_scalar_scalar", module);
  fn->addFnAttr (llvm::Attribute::AlwaysInline);
  body = llvm::BasicBlock::Create (ctx, "body", fn);
  builder.SetInsertPoint (body);
  builder.CreateStore (++fn->arg_begin (), fn->arg_begin ());
  builder.CreateRetVoid ();
  llvm::verifyFunction (*fn);
  assign_fn.add_overload (fn, false, 0, scalar, scalar);

  // now for binary scalar operations
  // FIXME: Finish all operations
  add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd);
  add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub);
  add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul);
  add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul);

  // FIXME: Warn if rhs is zero
  add_binary_op (scalar, octave_value::op_div, llvm::Instruction::FDiv);
  add_binary_op (scalar, octave_value::op_el_div, llvm::Instruction::FDiv);

  // now for printing functions
  add_print (any, reinterpret_cast<void*> (&octave_jit_print_any));
  add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double));

  // bounds check for for loop
  args.resize (2);
  args[0] = range->to_llvm ();
  args[1] = index->to_llvm ();
  ft = llvm::FunctionType::get (bool_t, args, false);
  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                               "octave_jit_simple_for_range", module);
  fn->addFnAttr (llvm::Attribute::AlwaysInline);
  body = llvm::BasicBlock::Create (ctx, "body", fn);
  builder.SetInsertPoint (body);
  {
    llvm::Value *nelem
      = builder.CreateExtractValue (fn->arg_begin (), 3);
    // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ());
    llvm::Value *idx = ++fn->arg_begin ();
    llvm::Value *ret = builder.CreateICmpULT (idx, nelem);
    builder.CreateRet (ret);
  }
  llvm::verifyFunction (*fn);
  simple_for_check.add_overload (fn, false, boolean, range, index);

  // increment for for loop
  args.resize (1);
  args[0] = index->to_llvm ();
  ft = llvm::FunctionType::get (index->to_llvm (), args, false);
  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                               "octave_jit_imple_for_range_incr", module);
  fn->addFnAttr (llvm::Attribute::AlwaysInline);
  body = llvm::BasicBlock::Create (ctx, "body", fn);
  builder.SetInsertPoint (body);
  {
    llvm::Value *one = llvm::ConstantInt::get (index_t, 1);
    llvm::Value *idx = fn->arg_begin ();
    llvm::Value *ret = builder.CreateAdd (idx, one);
    builder.CreateRet (ret);
  }
  llvm::verifyFunction (*fn);
  simple_for_incr.add_overload (fn, false, index, index);

  // index variabe for for loop
  args.resize (2);
  args[0] = range->to_llvm ();
  args[1] = index->to_llvm ();
  ft = llvm::FunctionType::get (dbl, args, false);
  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                               "octave_jit_simple_for_idx", module);
  fn->addFnAttr (llvm::Attribute::AlwaysInline);
  body = llvm::BasicBlock::Create (ctx, "body", fn);
  builder.SetInsertPoint (body);
  {
    llvm::Value *idx = ++fn->arg_begin ();
    llvm::Value *didx = builder.CreateUIToFP (idx, dbl);
    llvm::Value *rng = fn->arg_begin ();
    llvm::Value *base = builder.CreateExtractValue (rng, 0);
    llvm::Value *inc = builder.CreateExtractValue (rng, 2);

    llvm::Value *ret = builder.CreateFMul (didx, inc);
    ret = builder.CreateFAdd (base, ret);
    builder.CreateRet (ret);
  }
  llvm::verifyFunction (*fn);
  simple_for_index.add_overload (fn, false, scalar, range, index);
}

void
jit_typeinfo::add_print (jit_type *ty, void *call)
{
  llvm::LLVMContext& ctx = llvm::getGlobalContext ();
  llvm::Type *void_t = llvm::Type::getVoidTy (ctx);
  std::vector<llvm::Type *> args (2);
  args[0] = llvm::Type::getInt8PtrTy (ctx);
  args[1] = ty->to_llvm ();

  std::stringstream name;
  name << "octave_jit_print_" << ty->name ();

  llvm::FunctionType *print_ty = llvm::FunctionType::get (void_t, args, false);
  llvm::Function *fn = llvm::Function::Create (print_ty,
                                               llvm::Function::ExternalLinkage,
                                               name.str (), module);
  engine->addGlobalMapping (fn, call);

  jit_function::overload ol (fn, false, 0, ty);
  print_fn.add_overload (ol);
}

void
jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op)
{
  llvm::LLVMContext& ctx = llvm::getGlobalContext ();
  std::vector<llvm::Type *> args (2, ty->to_llvm ());
  llvm::FunctionType *ft = llvm::FunctionType::get (ty->to_llvm (), args,
                                                    false);

  std::stringstream fname;
  octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op);
  fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op)
        << "_" << ty->name ();

  llvm::Function *fn = llvm::Function::Create (ft,
                                               llvm::Function::ExternalLinkage,
                                               fname.str (),
                                               module);
  fn->addFnAttr (llvm::Attribute::AlwaysInline);
  llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn);
  llvm::IRBuilder<> fn_builder (block);
  llvm::Instruction::BinaryOps temp
    = static_cast<llvm::Instruction::BinaryOps>(llvm_op);
  llvm::Value *ret = fn_builder.CreateBinOp (temp, fn->arg_begin (),
                                             ++fn->arg_begin ());
  fn_builder.CreateRet (ret);
  llvm::verifyFunction (*fn);

  jit_function::overload ol(fn, false, ty, ty, ty);
  binary_ops[op].add_overload (ol);
}

jit_type*
jit_typeinfo::type_of (const octave_value &ov) const
{
  if (ov.is_undefined () || ov.is_function ())
    return 0;

  if (ov.is_double_type () && ov.is_real_scalar ())
    return get_scalar ();

  if (ov.is_range ())
    return get_range ();

  return get_any ();
}

const jit_function&
jit_typeinfo::binary_op (int op) const
{
  return binary_ops[op];
}

const jit_function::overload&
jit_typeinfo::assign_op (jit_type *lhs, jit_type *rhs) const
{
  assert (lhs == rhs);
  return assign_fn.get_overload (lhs, rhs);
}

const jit_function::overload&
jit_typeinfo::print_value (jit_type *to_print) const
{
  return print_fn.get_overload (to_print);
}

void
jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv)
{
  // duplication here can probably be removed somehow
  if (type == any)
    to_generic (type, gv, octave_value ());
  else if (type == scalar)
    to_generic (type, gv, octave_value (0));
  else if (type == range)
    to_generic (type, gv, octave_value (Range ()));
  else
    assert (false && "Type not supported yet");
}

void
jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov)
{
  if (type == any)
    {
      octave_base_value *obv = ov.internal_rep ();
      obv->grab ();
      ov_out.push_back (obv);
      gv.PointerVal = &ov_out.back ();
    }
  else if (type == scalar)
    {
      scalar_out.push_back (ov.double_value ());
      gv.PointerVal = &scalar_out.back ();
    }
  else if (type == range)
    {
      range_out.push_back (ov.range_value ());
      gv.PointerVal = &range_out.back ();
    }
  else
    assert (false && "Type not supported yet");
}

octave_value
jit_typeinfo::to_octave_value (jit_type *type, llvm::GenericValue& gv)
{
  if (type == any)
    {
      octave_base_value **ptr = reinterpret_cast<octave_base_value **>(gv.PointerVal);
      return octave_value (*ptr);
    }
  else if (type == scalar)
    {
      double *ptr = reinterpret_cast<double *>(gv.PointerVal);
      return octave_value (*ptr);
    }
  else if (type == range)
    {
      jit_range *ptr = reinterpret_cast<jit_range *>(gv.PointerVal);
      Range rng = *ptr;
      return octave_value (rng);
    }
  else
    assert (false && "Type not supported yet");
}

void
jit_typeinfo::reset_generic (void)
{
  scalar_out.clear ();
  ov_out.clear ();
  range_out.clear ();
}

jit_type*
jit_typeinfo::new_type (const std::string& name, bool force_init,
                        jit_type *parent, llvm::Type *llvm_type)
{
  jit_type *ret = new jit_type (name, force_init, parent, llvm_type, next_id++);
  id_to_type.push_back (ret);
  return ret;
}

// -------------------- jit_infer --------------------
void
jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds)
{
  argin.insert ("#bounds");
  types["#bounds"] = bounds;

  infer_simple_for (cmd, bounds);
}

void
jit_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&)
{
  fail ();
}

void
jit_infer::visit_argument_list (tree_argument_list&)
{
  fail ();
}

void
jit_infer::visit_binary_expression (tree_binary_expression& be)
{
  if (is_lvalue)
    fail ();

  tree_expression *lhs = be.lhs ();
  lhs->accept (*this);
  jit_type *tlhs = type_stack.back ();
  type_stack.pop_back ();

  tree_expression *rhs = be.rhs ();
  rhs->accept (*this);
  jit_type *trhs = type_stack.back ();

  jit_type *result = tinfo->binary_op_result (be.op_type (), tlhs, trhs);
  if (! result)
    fail ();

  type_stack.push_back (result);
}

void
jit_infer::visit_break_command (tree_break_command&)
{
  fail ();
}

void
jit_infer::visit_colon_expression (tree_colon_expression&)
{
  fail ();
}

void
jit_infer::visit_continue_command (tree_continue_command&)
{
  fail ();
}

void
jit_infer::visit_global_command (tree_global_command&)
{
  fail ();
}

void
jit_infer::visit_persistent_command (tree_persistent_command&)
{
  fail ();
}

void
jit_infer::visit_decl_elt (tree_decl_elt&)
{
  fail ();
}

void
jit_infer::visit_decl_init_list (tree_decl_init_list&)
{
  fail ();
}

void
jit_infer::visit_simple_for_command (tree_simple_for_command& cmd)
{
  tree_expression *control = cmd.control_expr ();
  control->accept (*this);

  jit_type *control_t = type_stack.back ();
  type_stack.pop_back ();

  infer_simple_for (cmd, control_t);
}

void
jit_infer::visit_complex_for_command (tree_complex_for_command&)
{
  fail ();
}

void
jit_infer::visit_octave_user_script (octave_user_script&)
{
  fail ();
}

void
jit_infer::visit_octave_user_function (octave_user_function&)
{
  fail ();
}

void
jit_infer::visit_octave_user_function_header (octave_user_function&)
{
  fail ();
}

void
jit_infer::visit_octave_user_function_trailer (octave_user_function&)
{
  fail ();
}

void
jit_infer::visit_function_def (tree_function_def&)
{
  fail ();
}

void
jit_infer::visit_identifier (tree_identifier& ti)
{
  handle_identifier (ti.name (), ti.do_lookup ());
}

void
jit_infer::visit_if_clause (tree_if_clause&)
{
  fail ();
}

void
jit_infer::visit_if_command (tree_if_command&)
{
  fail ();
}

void
jit_infer::visit_if_command_list (tree_if_command_list&)
{
  fail ();
}

void
jit_infer::visit_index_expression (tree_index_expression&)
{
  fail ();
}

void
jit_infer::visit_matrix (tree_matrix&)
{
  fail ();
}

void
jit_infer::visit_cell (tree_cell&)
{
  fail ();
}

void
jit_infer::visit_multi_assignment (tree_multi_assignment&)
{
  fail ();
}

void
jit_infer::visit_no_op_command (tree_no_op_command&)
{
  fail ();
}

void
jit_infer::visit_constant (tree_constant& tc)
{
  if (is_lvalue)
    fail ();

  octave_value v = tc.rvalue1 ();
  jit_type *type = tinfo->type_of (v);
  if (! type)
    fail ();

  type_stack.push_back (type);
}

void
jit_infer::visit_fcn_handle (tree_fcn_handle&)
{
  fail ();
}

void
jit_infer::visit_parameter_list (tree_parameter_list&)
{
  fail ();
}

void
jit_infer::visit_postfix_expression (tree_postfix_expression&)
{
  fail ();
}

void
jit_infer::visit_prefix_expression (tree_prefix_expression&)
{
  fail ();
}

void
jit_infer::visit_return_command (tree_return_command&)
{
  fail ();
}

void
jit_infer::visit_return_list (tree_return_list&)
{
  fail ();
}

void
jit_infer::visit_simple_assignment (tree_simple_assignment& tsa)
{
  if (is_lvalue)
    fail ();

  // resolve rhs
  is_lvalue = false;
  tree_expression *rhs = tsa.right_hand_side ();
  rhs->accept (*this);

  jit_type *trhs = type_stack.back ();
  type_stack.pop_back ();

  // resolve lhs
  is_lvalue = true;
  rvalue_type = trhs;
  tree_expression *lhs = tsa.left_hand_side ();
  lhs->accept (*this);

  // we don't pop back here, as the resulting type should be the rhs type
  // which is equal to the lhs type anways
  jit_type *tlhs = type_stack.back ();
  if (tlhs != trhs)
    fail ();

  is_lvalue = false;
  rvalue_type = 0;
}

void
jit_infer::visit_statement (tree_statement& stmt)
{
  if (is_lvalue)
    fail ();

  tree_command *cmd = stmt.command ();
  tree_expression *expr = stmt.expression ();

  if (cmd)
    cmd->accept (*this);
  else
    {
      // ok, this check for ans appears three times as cp
      bool do_bind_ans = false;

      if (expr->is_identifier ())
        {
          tree_identifier *id = dynamic_cast<tree_identifier *> (expr);

          do_bind_ans = (! id->is_variable ());
        }
      else
        do_bind_ans = (! expr->is_assignment_expression ());

      expr->accept (*this);

      if (do_bind_ans)
        {
          is_lvalue = true;
          rvalue_type = type_stack.back ();
          type_stack.pop_back ();
          handle_identifier ("ans", symbol_table::varval ("ans"));

          if (rvalue_type != type_stack.back ())
            fail ();

          is_lvalue = false;
          rvalue_type = 0;
        }

      type_stack.pop_back ();
    }
}

void
jit_infer::visit_statement_list (tree_statement_list& lst)
{
  tree_statement_list::iterator iter;
  for (iter = lst.begin (); iter != lst.end (); ++iter)
    {
      tree_statement *stmt = *iter;
      assert (stmt); // FIXME: jwe can this be null?
      stmt->accept (*this);
    }
}

void
jit_infer::visit_switch_case (tree_switch_case&)
{
  fail ();
}

void
jit_infer::visit_switch_case_list (tree_switch_case_list&)
{
  fail ();
}

void
jit_infer::visit_switch_command (tree_switch_command&)
{
  fail ();
}

void
jit_infer::visit_try_catch_command (tree_try_catch_command&)
{
  fail ();
}

void
jit_infer::visit_unwind_protect_command (tree_unwind_protect_command&)
{
  fail ();
}

void
jit_infer::visit_while_command (tree_while_command&)
{
  fail ();
}

void
jit_infer::visit_do_until_command (tree_do_until_command&)
{
  fail ();
}

void
jit_infer::infer_simple_for (tree_simple_for_command& cmd,
                             jit_type *bounds)
{
  if (is_lvalue)
    fail ();

  jit_type *iter = tinfo->get_simple_for_index_result (bounds);
  if (! iter)
    fail ();

  is_lvalue = true;
  rvalue_type = iter;
  tree_expression *lhs = cmd.left_hand_side ();
  lhs->accept (*this);
  if (type_stack.back () != iter)
    fail ();
  type_stack.pop_back ();
  is_lvalue = false;
  rvalue_type = 0;

  tree_statement_list *body = cmd.body ();
  body->accept (*this);
}

void
jit_infer::handle_identifier (const std::string& name, octave_value v)
{
  type_map::iterator iter = types.find (name);
  if (iter == types.end ())
    {
      jit_type *ty = tinfo->type_of (v);
      if (is_lvalue)
        {
          if (! ty)
            ty = rvalue_type;
        }
      else
        {
          if (! ty)
            fail ();

          argin.insert (name);
        }

      types[name] = ty;
      type_stack.push_back (ty);
    }
  else
    type_stack.push_back (iter->second);
}

// -------------------- jit_generator --------------------
jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee,
                              const std::set<std::string>& argin,
                              const type_map& infered_types, bool have_bounds)
  : tinfo (ti), is_lvalue (false)
{
  // determine the function type through the type of all variables
  std::vector<llvm::Type *> arg_types (infered_types.size ());
  size_t idx = 0;
  type_map::const_iterator iter;
  for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++idx)
    arg_types[idx] = iter->second->to_llvm_arg ();

  // now create the LLVM function from our determined types
  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
  llvm::Type *tvoid = llvm::Type::getVoidTy (ctx);
  llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false);
  function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                                     "foobar", module);

  // declare each argument and copy its initial value
  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function);
  builder.SetInsertPoint (body);
  llvm::Function::arg_iterator arg_iter = function->arg_begin();
  for (iter = infered_types.begin (); iter != infered_types.end ();
       ++iter, ++arg_iter)

    {
      llvm::Type *vartype = iter->second->to_llvm ();
      llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first);
      variables[iter->first] = value (iter->second, var);

      if (iter->second->force_init () || argin.count (iter->first))
        {
          llvm::Value *loaded_arg = builder.CreateLoad (arg_iter);
          builder.CreateStore (loaded_arg, var);
        }
    }

  // generate body
  try
    {
      tree_simple_for_command *cmd = dynamic_cast<tree_simple_for_command*>(&tee);
      if (have_bounds && cmd)
        {
          value bounds = variables["#bounds"];
          bounds.second = builder.CreateLoad (bounds.second);
          emit_simple_for (*cmd, bounds, true);
        }
      else
        tee.accept (*this);
    }
  catch (const jit_fail_exception&)
    {
      function->eraseFromParent ();
      function = 0;
      return;
    }

  // copy computed values back into arguments
  arg_iter = function->arg_begin ();
  for (iter = infered_types.begin (); iter != infered_types.end ();
       ++iter, ++arg_iter)
    {
      llvm::Value *var = variables[iter->first].second;
      llvm::Value *loaded_var = builder.CreateLoad (var);
      builder.CreateStore (loaded_var, arg_iter);
    }
  builder.CreateRetVoid ();
}

void
jit_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&)
{
  fail ();
}

void
jit_generator::visit_argument_list (tree_argument_list&)
{
  fail ();
}

void
jit_generator::visit_binary_expression (tree_binary_expression& be)
{
  tree_expression *lhs = be.lhs ();
  lhs->accept (*this);
  value lhsv = value_stack.back ();
  value_stack.pop_back ();

  tree_expression *rhs = be.rhs ();
  rhs->accept (*this);
  value rhsv = value_stack.back ();
  value_stack.pop_back ();

  const jit_function::overload& ol
    = tinfo->binary_op_overload (be.op_type (), lhsv.first, rhsv.first);

  if (! ol.function)
    fail ();

  llvm::Value *result = builder.CreateCall2 (ol.function, lhsv.second,
                                             rhsv.second);
  push_value (ol.result, result);
}

void
jit_generator::visit_break_command (tree_break_command&)
{
  fail ();
}

void
jit_generator::visit_colon_expression (tree_colon_expression&)
{
  fail ();
}

void
jit_generator::visit_continue_command (tree_continue_command&)
{
  fail ();
}

void
jit_generator::visit_global_command (tree_global_command&)
{
  fail ();
}

void
jit_generator::visit_persistent_command (tree_persistent_command&)
{
  fail ();
}

void
jit_generator::visit_decl_elt (tree_decl_elt&)
{
  fail ();
}

void
jit_generator::visit_decl_init_list (tree_decl_init_list&)
{
  fail ();
}

void
jit_generator::visit_simple_for_command (tree_simple_for_command& cmd)
{
  if (is_lvalue)
    fail ();

  tree_expression *control = cmd.control_expr ();
  assert (control); // FIXME: jwe, can this be null?

  control->accept (*this);
  value over = value_stack.back ();
  value_stack.pop_back ();

  emit_simple_for (cmd, over, false);
}

void
jit_generator::visit_complex_for_command (tree_complex_for_command&)
{
  fail ();
}

void
jit_generator::visit_octave_user_script (octave_user_script&)
{
  fail ();
}

void
jit_generator::visit_octave_user_function (octave_user_function&)
{
  fail ();
}

void
jit_generator::visit_octave_user_function_header (octave_user_function&)
{
  fail ();
}

void
jit_generator::visit_octave_user_function_trailer (octave_user_function&)
{
  fail ();
}

void
jit_generator::visit_function_def (tree_function_def&)
{
  fail ();
}

void
jit_generator::visit_identifier (tree_identifier& ti)
{
  std::string name = ti.name ();
  value variable = variables[name];
  if (is_lvalue)
    value_stack.push_back (variable);
  else
    {
      llvm::Value *load = builder.CreateLoad (variable.second, name);
      push_value (variable.first, load);
    }
}

void
jit_generator::visit_if_clause (tree_if_clause&)
{
  fail ();
}

void
jit_generator::visit_if_command (tree_if_command&)
{
  fail ();
}

void
jit_generator::visit_if_command_list (tree_if_command_list&)
{
  fail ();
}

void
jit_generator::visit_index_expression (tree_index_expression&)
{
  fail ();
}

void
jit_generator::visit_matrix (tree_matrix&)
{
  fail ();
}

void
jit_generator::visit_cell (tree_cell&)
{
  fail ();
}

void
jit_generator::visit_multi_assignment (tree_multi_assignment&)
{
  fail ();
}

void
jit_generator::visit_no_op_command (tree_no_op_command&)
{
  fail ();
}

void
jit_generator::visit_constant (tree_constant& tc)
{
  octave_value v = tc.rvalue1 ();
  llvm::LLVMContext& ctx = llvm::getGlobalContext ();
  if (v.is_real_scalar () && v.is_double_type ())
    {
      double dv = v.double_value ();
      llvm::Value *lv = llvm::ConstantFP::get (ctx, llvm::APFloat (dv));
      push_value (tinfo->get_scalar (), lv);
    }
  else if (v.is_range ())
    {
      Range rng = v.range_value ();
      llvm::Type *range = tinfo->get_range_llvm ();
      llvm::Type *scalar = tinfo->get_scalar_llvm ();
      llvm::Type *index = tinfo->get_index_llvm ();

      std::vector<llvm::Constant *> values (4);
      values[0] = llvm::ConstantFP::get (scalar, rng.base ());
      values[1] = llvm::ConstantFP::get (scalar, rng.limit ());
      values[2] = llvm::ConstantFP::get (scalar, rng.inc ());
      values[3] = llvm::ConstantInt::get (index, rng.nelem ());

      llvm::StructType *llvm_range = llvm::cast<llvm::StructType>(range);
      llvm::Value *lv = llvm::ConstantStruct::get (llvm_range, values);
      push_value (tinfo->get_range (), lv);
    }
  else
    fail ();
}

void
jit_generator::visit_fcn_handle (tree_fcn_handle&)
{
  fail ();
}

void
jit_generator::visit_parameter_list (tree_parameter_list&)
{
  fail ();
}

void
jit_generator::visit_postfix_expression (tree_postfix_expression&)
{
  fail ();
}

void
jit_generator::visit_prefix_expression (tree_prefix_expression&)
{
  fail ();
}

void
jit_generator::visit_return_command (tree_return_command&)
{
  fail ();
}

void
jit_generator::visit_return_list (tree_return_list&)
{
  fail ();
}

void
jit_generator::visit_simple_assignment (tree_simple_assignment& tsa)
{
  if (is_lvalue)
    fail ();

  // resolve lhs
  is_lvalue = true;
  tree_expression *lhs = tsa.left_hand_side ();
  lhs->accept (*this);

  value lhsv = value_stack.back ();
  value_stack.pop_back ();

  // resolve rhs
  is_lvalue = false;
  tree_expression *rhs = tsa.right_hand_side ();
  rhs->accept (*this);

  value rhsv = value_stack.back ();
  value_stack.pop_back ();

  // do assign, then store rhs as the result
  jit_function::overload ol = tinfo->assign_op (lhsv.first, rhsv.first);
  builder.CreateCall2 (ol.function, lhsv.second, rhsv.second);

  if (tsa.print_result ())
    emit_print (lhs->name (), rhsv);

  value_stack.push_back (rhsv);
}

void
jit_generator::visit_statement (tree_statement& stmt)
{
  tree_command *cmd = stmt.command ();
  tree_expression *expr = stmt.expression ();

  if (cmd)
    cmd->accept (*this);
  else
    {
      // stolen from tree_evaluator::visit_statement
      bool do_bind_ans = false;

      if (expr->is_identifier ())
        {
          tree_identifier *id = dynamic_cast<tree_identifier *> (expr);

          do_bind_ans = (! id->is_variable ());
        }
      else
        do_bind_ans = (! expr->is_assignment_expression ());

      expr->accept (*this);

      if (do_bind_ans)
        {
          value rhs = value_stack.back ();
          value ans = variables["ans"];
          if (ans.first != rhs.first)
            fail ();

          builder.CreateStore (rhs.second, ans.second);

          if (expr->print_result ())
            emit_print ("ans", rhs);
        }
      else if (expr->is_identifier () && expr->print_result ())
        {
          // FIXME: ugly hack, we need to come up with a way to pass
          // nargout to visit_identifier
          emit_print (expr->name (), value_stack.back ());
        }


      value_stack.pop_back ();
    }
}

void
jit_generator::visit_statement_list (tree_statement_list& lst)
{
  tree_statement_list::iterator iter;
  for (iter = lst.begin (); iter != lst.end (); ++iter)
    {
      tree_statement *stmt = *iter;
      assert (stmt); // FIXME: jwe can this be null?
      stmt->accept (*this);
    }
}

void
jit_generator::visit_switch_case (tree_switch_case&)
{
  fail ();
}

void
jit_generator::visit_switch_case_list (tree_switch_case_list&)
{
  fail ();
}

void
jit_generator::visit_switch_command (tree_switch_command&)
{
  fail ();
}

void
jit_generator::visit_try_catch_command (tree_try_catch_command&)
{
  fail ();
}

void
jit_generator::visit_unwind_protect_command (tree_unwind_protect_command&)
{
  fail ();
}

void
jit_generator::visit_while_command (tree_while_command&)
{
  fail ();
}

void
jit_generator::visit_do_until_command (tree_do_until_command&)
{
  fail ();
}

void
jit_generator::emit_simple_for (tree_simple_for_command& cmd, value over,
                                bool atleast_once)
{
  if (is_lvalue)
    fail ();

  jit_type *index = tinfo->get_index ();
  llvm::Value *init_index = 0;
  if (over.first == tinfo->get_range ())
    init_index = llvm::ConstantInt::get (index->to_llvm (), 0);
  else
    fail ();

  llvm::Value *llvm_index = builder.CreateAlloca (index->to_llvm (), 0, "index");
  builder.CreateStore (init_index, llvm_index);

  // FIXME: Support break
  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "for_body", function);
  llvm::BasicBlock *cond_check = llvm::BasicBlock::Create (ctx, "for_check", function);
  llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "for_tail", function);

  // initialize the iter from the index
  if (atleast_once)
    builder.CreateBr (body);
  else
    builder.CreateBr (cond_check);

  builder.SetInsertPoint (body);

  is_lvalue = true;
  tree_expression *lhs = cmd.left_hand_side ();
  lhs->accept (*this);
  is_lvalue = false;

  value lhsv = value_stack.back ();
  value_stack.pop_back ();

  const jit_function::overload& index_ol = tinfo->get_simple_for_index (over.first);
  llvm::Value *lindex = builder.CreateLoad (llvm_index);
  llvm::Value *llvm_iter = builder.CreateCall2 (index_ol.function, over.second, lindex);
  value iter(index_ol.result, llvm_iter);

  jit_function::overload assign = tinfo->assign_op (lhsv.first, iter.first);
  builder.CreateCall2 (assign.function, lhsv.second, iter.second);

  tree_statement_list *lst = cmd.body ();
  lst->accept (*this);

  llvm::Value *one = llvm::ConstantInt::get (index->to_llvm (), 1);
  lindex = builder.CreateLoad (llvm_index);
  lindex = builder.CreateAdd (lindex, one);
  builder.CreateStore (lindex, llvm_index);
  builder.CreateBr (cond_check);

  builder.SetInsertPoint (cond_check);
  lindex = builder.CreateLoad (llvm_index);
  const jit_function::overload& check_ol = tinfo->get_simple_for_check (over.first);
  llvm::Value *cond = builder.CreateCall2 (check_ol.function, over.second, lindex);
  builder.CreateCondBr (cond, body, tail);

  builder.SetInsertPoint (tail);
}

void
jit_generator::emit_print (const std::string& name, const value& v)
{
  const jit_function::overload& ol = tinfo->print_value (v.first);
  if (! ol.function)
    fail ();

  llvm::Value *str = builder.CreateGlobalStringPtr (name);
  builder.CreateCall2 (ol.function, str, v.second);
}

// -------------------- tree_jit --------------------

tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0)
{
  llvm::InitializeNativeTarget ();
  module = new llvm::Module ("octave", context);
}

tree_jit::~tree_jit (void)
{
  delete tinfo;
}

bool
tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds)
{
  if (! initialize ())
    return false;

  jit_type *bounds_t = tinfo->type_of (bounds);
  jit_info *jinfo = cmd.get_info (bounds_t);
  if (! jinfo)
    {
      jinfo = new jit_info (*this, cmd, bounds_t);
      cmd.stash_info (bounds_t, jinfo);
    }

  return jinfo->execute (bounds);
}

bool
tree_jit::initialize (void)
{
  if (engine)
    return true;

  // sometimes this fails pre main
  engine = llvm::ExecutionEngine::createJIT (module);

  if (! engine)
    return false;

  module_pass_manager = new llvm::PassManager ();
  module_pass_manager->add (llvm::createAlwaysInlinerPass ());

  pass_manager = new llvm::FunctionPassManager (module);
  pass_manager->add (new llvm::TargetData(*engine->getTargetData ()));
  pass_manager->add (llvm::createBasicAliasAnalysisPass ());
  pass_manager->add (llvm::createPromoteMemoryToRegisterPass ());
  pass_manager->add (llvm::createInstructionCombiningPass ());
  pass_manager->add (llvm::createReassociatePass ());
  pass_manager->add (llvm::createGVNPass ());
  pass_manager->add (llvm::createCFGSimplificationPass ());
  pass_manager->doInitialization ();

  tinfo = new jit_typeinfo (module, engine);

  return true;
}


void
tree_jit::optimize (llvm::Function *fn)
{
  module_pass_manager->run (*module);
  pass_manager->run (*fn);
}

// -------------------- jit_info --------------------
jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd,
                    jit_type *bounds) : tinfo (tjit.get_typeinfo ()),
                                        engine (tjit.get_engine ())
{
  jit_infer infer(tinfo);

  try
    {
      infer.infer (cmd, bounds);
    }
  catch (const jit_fail_exception&)
    {
      function = 0;
      return;
    }

  argin = infer.get_argin ();
  types = infer.get_types ();

  jit_generator gen(tinfo, tjit.get_module (), cmd, argin, types);
  function = gen.get_function ();

  if (function)
    {
      if (debug_print)
        {
          std::cout << "Compiled code:\n";
          std::cout << cmd.str_print_code () << std::endl;

          std::cout << "Before optimization:\n";

          llvm::raw_os_ostream os (std::cout);
          function->print (os);
        }
      llvm::verifyFunction (*function);
      tjit.optimize (function);

      if (debug_print)
        {
          std::cout << "After optimization:\n";

          llvm::raw_os_ostream os (std::cout);
          function->print (os);
        }
    }
}

bool
jit_info::execute (const octave_value& bounds) const
{
  if (! function)
    return false;

  std::vector<llvm::GenericValue> args (types.size ());
  size_t idx;
  type_map::const_iterator iter;
  for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx)
    {
      if (argin.count (iter->first))
        {
          octave_value ov;
          if (iter->first == "#bounds")
            ov = bounds;
          else
            ov = symbol_table::varval (iter->first);

          tinfo->to_generic (iter->second, args[idx], ov);
        }
      else
        tinfo->to_generic (iter->second, args[idx]);
    }

  engine->runFunction (function, args);

  for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx)
    {
      octave_value result = tinfo->to_octave_value (iter->second, args[idx]);
      symbol_table::varref (iter->first) = result;
    }

  tinfo->reset_generic ();

  return true;
}

bool
jit_info::match () const
{
  for (std::set<std::string>::iterator iter = argin.begin ();
       iter != argin.end (); ++iter)
    {
      if (*iter == "#bounds")
        continue;

      jit_type *required_type = types.find (*iter)->second;
      octave_value val = symbol_table::varref (*iter);
      jit_type *current_type = tinfo->type_of (val);

      // FIXME: should be: ! required_type->is_parent (current_type)
      if (required_type != current_type)
        return false;
    }

  return true;
}