changeset 15337:3f43e9d6d86e

JIT compile anonymous functions * jit-ir.h (jit_block::front, jit_block::back): New function. (jit_call::jit_call): New overloads. (jit_return): New class. * jit-typeinfo.cc (octave_jit_create_undef): New function. (jit_operation::to_idx): Correctly handle empty type vector. (jit_typeinfo::jit_typeinfo): Add destroy_fn and initialize create_undef. * jit-typeinfo.h (jit_typeinfo::get_any_ptr, jit_typeinfo::destroy, jit_typeinfo::create_undef): New function. * pt-jit.cc (jit_convert::jit_convert): Add overload and refactor. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute, jit_function_info::jit_function_info, jit_function_info::execute, jit_function_info::match): New function. (jit_convert::get_variable): Support function variable lookup. (jit_convert_llvm::convert): Handle loop/function agnostic stuff. (jit_convert_llvm::visit): Handle function creation as well. (tree_jit::execute): Move implementation to tree_jit::do_execute. (jit_info::compile): Call convert_loop instead of convert. * pt-jit.h (jit_convert::jit_convert): New overload. (jit_convert::initialize, jit_convert_llvm::convert_loop, jit_convert_llvm::convert_function, tree_jit::do_execute): New function. (jit_convert::create_variable, jit_convert_llvm::initialize): Update signature. (tree_jit::execute): Made static. (tree_jit::tree_jit): Made private. (jit_function_info): New class. * ov-usr-fcn.cc (octave_user_function::~octave_user_function): Delete jit_info. (octave_user_function::octave_user_function): Maybe JIT and use is_special_expr and special_expr. (octave_user_function::special_expr): New function. * ov-usr-fcn.h (octave_user_function::is_special_expr, octave_user_function::special_expr, octave_user_function::get_info, octave_user_function::stash_info): New function. * pt-decl.h (tree_decl_elt::name): New function. * pt-eval.cc (tree_evaluator::visit_simple_for_command, tree_evaluator::visit_while_command): Use static tree_jit methods.
author Max Brister <max@2bass.com>
date Sun, 09 Sep 2012 00:29:00 -0600
parents 5fff79162342
children dc39c1d84c5b
files libinterp/interp-core/jit-ir.h libinterp/interp-core/jit-typeinfo.cc libinterp/interp-core/jit-typeinfo.h libinterp/interp-core/pt-jit.cc libinterp/interp-core/pt-jit.h libinterp/octave-value/ov-usr-fcn.cc libinterp/octave-value/ov-usr-fcn.h libinterp/parse-tree/pt-decl.h libinterp/parse-tree/pt-eval.cc
diffstat 9 files changed, 712 insertions(+), 126 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/interp-core/jit-ir.h	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/jit-ir.h	Sun Sep 09 00:29:00 2012 -0600
@@ -42,6 +42,7 @@
   JIT_METH(call);                               \
   JIT_METH(extract_argument);                   \
   JIT_METH(store_argument);                     \
+  JIT_METH(return);                             \
   JIT_METH(phi);                                \
   JIT_METH(variable);                           \
   JIT_METH(error_check);                        \
@@ -768,6 +769,10 @@
     return true;
   }
 
+  jit_instruction *front (void) { return instructions.front (); }
+
+  jit_instruction *back (void) { return instructions.back (); }
+
   JIT_VALUE_ACCEPT;
 private:
   void internal_append (jit_instruction *instr);
@@ -1149,6 +1154,21 @@
 jit_call : public jit_instruction
 {
 public:
+  jit_call (const jit_operation& (*aoperation) (void))
+    : moperation (aoperation ())
+  {
+    const jit_function& ol = overload ();
+    if (ol.valid ())
+      stash_type (ol.result ());
+  }
+
+  jit_call (const jit_operation& aoperation) : moperation (aoperation)
+  {
+    const jit_function& ol = overload ();
+    if (ol.valid ())
+      stash_type (ol.result ());
+  }
+
 #define JIT_CALL_CONST(N)                                               \
   jit_call (const jit_operation& aoperation,                            \
             OCT_MAKE_DECL_LIST (jit_value *, arg, N))                   \
@@ -1366,6 +1386,38 @@
 };
 
 class
+jit_return : public jit_instruction
+{
+public:
+  jit_return (void) {}
+
+  jit_return (jit_value *retval) : jit_instruction (retval) {}
+
+  jit_value *result (void) const
+  {
+    return argument_count () ? argument (0) : 0;
+  }
+
+  jit_type *result_type (void) const
+  {
+    jit_value *res = result ();
+    return res ? res->type () : 0;
+  }
+
+  virtual std::ostream& print (std::ostream& os, size_t indent = 0) const
+  {
+    print_indent (os, indent) << "return";
+
+    if (result ())
+      os << " " << *result ();
+
+    return os;
+  }
+
+  JIT_VALUE_ACCEPT;
+};
+
+class
 jit_ir_walker
 {
 public:
--- a/libinterp/interp-core/jit-typeinfo.cc	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/jit-typeinfo.cc	Sun Sep 09 00:29:00 2012 -0600
@@ -366,6 +366,16 @@
     return idx < ndim ? mat->dimensions[idx] : 1;
 }
 
+extern "C" octave_base_value *
+octave_jit_create_undef (void)
+{
+  octave_value undef;
+  octave_base_value *ret = undef.internal_rep ();
+  ret->grab ();
+
+  return ret;
+}
+
 extern "C" Complex
 octave_jit_complex_div (Complex lhs, Complex rhs)
 {
@@ -791,14 +801,15 @@
 jit_operation::to_idx (const std::vector<jit_type*>& types) const
 {
   octave_idx_type numel = types.size ();
-  if (numel == 1)
-    numel = 2;
+  numel = std::max (2, numel);
 
   Array<octave_idx_type> idx (dim_vector (1, numel));
   for (octave_idx_type i = 0; i < static_cast<octave_idx_type> (types.size ());
        ++i)
     idx(i) = types[i]->type_id ();
 
+  if (types.size () == 0)
+    idx(0) = idx(1) = 0;
   if (types.size () == 1)
     {
       idx(1) = idx(0);
@@ -1149,6 +1160,14 @@
   fn.add_mapping (engine, &octave_jit_release_matrix);
   release_fn.add_overload (fn);
 
+  // destroy
+  destroy_fn = release_fn;
+  destroy_fn.stash_name ("destroy");
+  destroy_fn.add_overload (create_identity(scalar));
+  destroy_fn.add_overload (create_identity(boolean));
+  destroy_fn.add_overload (create_identity(index));
+  destroy_fn.add_overload (create_identity(complex));
+
   // now for binary scalar operations
   add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd);
   add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub);
@@ -1702,6 +1721,12 @@
                         scalar, matrix, index, index);
   end_fn.add_overload (fn);
 
+  // -------------------- create_undef --------------------
+  create_undef_fn.stash_name ("create_undef");
+  fn = create_function (jit_convention::external, "octave_jit_create_undef",
+                        any);
+  create_undef_fn.add_overload (fn);
+
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
   casts[complex->type_id ()].stash_name ("(complex)");
--- a/libinterp/interp-core/jit-typeinfo.h	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/jit-typeinfo.h	Sun Sep 09 00:29:00 2012 -0600
@@ -452,6 +452,8 @@
 
   static jit_type *get_scalar_ptr (void) { return instance->scalar_ptr; }
 
+  static jit_type *get_any_ptr (void) { return instance->any_ptr; }
+
   static jit_type *get_range (void) { return instance->range; }
 
   static jit_type *get_string (void) { return instance->string; }
@@ -498,6 +500,11 @@
     return instance->release_fn.overload (type);
   }
 
+  static const jit_operation& destroy (void)
+  {
+    return instance->destroy_fn;
+  }
+
   static const jit_operation& print_value (void)
   {
     return instance->print_fn;
@@ -563,6 +570,11 @@
   {
     return instance->do_end (value, index, count);
   }
+
+  static const jit_operation& create_undef (void)
+  {
+    return instance->create_undef_fn;
+  }
 private:
   jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e);
 
@@ -751,6 +763,7 @@
   std::vector<jit_operation> unary_ops;
   jit_operation grab_fn;
   jit_operation release_fn;
+  jit_operation destroy_fn;
   jit_operation print_fn;
   jit_operation for_init_fn;
   jit_operation for_check_fn;
@@ -761,6 +774,7 @@
   jit_paren_subsasgn paren_subsasgn_fn;
   jit_operation end1_fn;
   jit_operation end_fn;
+  jit_operation create_undef_fn;
 
   jit_function any_call;
 
--- a/libinterp/interp-core/pt-jit.cc	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/pt-jit.cc	Sun Sep 09 00:29:00 2012 -0600
@@ -65,22 +65,16 @@
 
 // -------------------- jit_convert --------------------
 jit_convert::jit_convert (tree &tee, jit_type *for_bounds)
-  : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false)
+  : converting_function (false)
 {
-  jit_instruction::reset_ids ();
-
-  entry_block = factory.create<jit_block> ("body");
-  final_block = factory.create<jit_block> ("final");
-  blocks.push_back (entry_block);
-  entry_block->mark_alive ();
-  block = entry_block;
+  initialize (symbol_table::current_scope ());
 
   if (for_bounds)
     create_variable (next_for_bounds (false), for_bounds);
 
   visit (tee);
 
-  // FIXME: Remove if we no longer only compile loops
+  // breaks must have been handled by the top level loop
   assert (! breaking);
   assert (breaks.empty ());
   assert (continues.empty ());
@@ -95,6 +89,91 @@
       if (name.size () && name[0] != '#')
         final_block->append (factory.create<jit_store_argument> (var));
     }
+
+  final_block->append (factory.create<jit_return> ());
+}
+
+jit_convert::jit_convert (octave_user_function& fcn,
+                          const std::vector<jit_type *>& args)
+  : converting_function (true)
+{
+  initialize (fcn.scope ());
+
+  tree_parameter_list *plist = fcn.parameter_list ();
+  tree_parameter_list *rlist = fcn.return_list ();
+  if (plist && plist->takes_varargs ())
+    throw jit_fail_exception ("varags not supported");
+
+  if (rlist && (rlist->size () > 1 || rlist->takes_varargs ()))
+    throw jit_fail_exception ("multiple returns not supported");
+
+  if (plist)
+    {
+      tree_parameter_list::iterator piter = plist->begin ();
+      for (size_t i = 0; i < args.size (); ++i, ++piter)
+        {
+          if (piter == plist->end ())
+            throw jit_fail_exception ("Too many parameter to function");
+
+          tree_decl_elt *elt = *piter;
+          std::string name = elt->name ();
+          create_variable (name, args[i]);
+        }
+    }
+
+  jit_value *return_value = 0;
+  if (fcn.is_special_expr ())
+    {
+      tree_expression *expr = fcn.special_expr ();
+      if (expr)
+        {
+          jit_variable *retvar = get_variable ("#return");
+          jit_value *retval = visit (expr);
+          block->append (factory.create<jit_assign> (retvar, retval));
+          return_value = retvar;
+        }
+    }
+  else
+    visit_statement_list (*fcn.body ());
+
+  // the user may use break or continue to exit the function. Because the
+  // function does not start as a loop, we can have one continue, one break, or
+  // a regular fallthrough to exit the function
+  if (continues.size ())
+    {
+      assert (! continues.size ());
+      finish_breaks (final_block, continues);
+    }
+  else if (breaks.size ())
+    finish_breaks (final_block, breaks);
+  else
+    block->append (factory.create<jit_branch> (final_block));
+  blocks.push_back (final_block);
+  block = final_block;
+
+  if (! return_value && rlist && rlist->size () == 1)
+    {
+      tree_decl_elt *elt = rlist->front ();
+      return_value = get_variable (elt->name ());
+    }
+
+  // FIXME: We should use live range analysis to delete variables where needed.
+  // For now we just delete everything at the end of the function.
+  for (variable_map::iterator iter = vmap.begin (); iter != vmap.end (); ++iter)
+    {
+      if (iter->second != return_value)
+        {
+          jit_call *call;
+          call = factory.create<jit_call> (&jit_typeinfo::destroy,
+                                           iter->second);
+          final_block->append (call);
+        }
+    }
+
+  if (return_value)
+    final_block->append (factory.create<jit_return> (return_value));
+  else
+    final_block->append (factory.create<jit_return> ());
 }
 
 void
@@ -719,6 +798,23 @@
   throw jit_fail_exception ();
 }
 
+void
+jit_convert::initialize (symbol_table::scope_id s)
+{
+  scope = s;
+  iterator_count = 0;
+  for_bounds_count = 0;
+  short_count = 0;
+  breaking = false;
+  jit_instruction::reset_ids ();
+
+  entry_block = factory.create<jit_block> ("body");
+  final_block = factory.create<jit_block> ("final");
+  blocks.push_back (entry_block);
+  entry_block->mark_alive ();
+  block = entry_block;
+}
+
 jit_call *
 jit_convert::create_checked_impl (jit_call *ret)
 {
@@ -749,20 +845,42 @@
   if (ret)
     return ret;
 
-  octave_value val = symbol_table::find (vname);
-  jit_type *type = jit_typeinfo::type_of (val);
-  bounds.push_back (type_bound (type, vname));
+  symbol_table::symbol_record record = symbol_table::find_symbol (vname, scope);
+  if (record.is_persistent () || record.is_global ())
+    throw jit_fail_exception ("Persistent and global not yet supported");
 
-  return create_variable (vname, type);
+  if (converting_function)
+    return create_variable (vname, jit_typeinfo::get_any (), false);
+  else
+    {
+      octave_value val = record.varval ();
+      jit_type *type = jit_typeinfo::type_of (val);
+      bounds.push_back (type_bound (type, vname));
+
+      return create_variable (vname, type);
+    }
 }
 
 jit_variable *
-jit_convert::create_variable (const std::string& vname, jit_type *type)
+jit_convert::create_variable (const std::string& vname, jit_type *type,
+                              bool isarg)
 {
   jit_variable *var = factory.create<jit_variable> (vname);
-  jit_extract_argument *extract;
-  extract = factory.create<jit_extract_argument> (type, var);
-  entry_block->prepend (extract);
+
+  if (isarg)
+    {
+      jit_extract_argument *extract;
+      extract = factory.create<jit_extract_argument> (type, var);
+      entry_block->prepend (extract);
+    }
+  else
+    {
+      jit_call *init = factory.create<jit_call> (&jit_typeinfo::create_undef);
+      jit_assign *assign = factory.create<jit_assign> (var, init);
+      entry_block->prepend (assign);
+      entry_block->prepend (init);
+    }
+
   return vmap[vname] = var;
 }
 
@@ -898,10 +1016,12 @@
 
 // -------------------- jit_convert_llvm --------------------
 llvm::Function *
-jit_convert_llvm::convert (llvm::Module *module,
-                           const jit_block_list& blocks,
-                           const std::list<jit_value *>& constants)
+jit_convert_llvm::convert_loop (llvm::Module *module,
+                                const jit_block_list& blocks,
+                                const std::list<jit_value *>& constants)
 {
+  converting_function = false;
+
   // for now just init arguments from entry, later we will have to do something
   // more interesting
   jit_block *entry_block = blocks.front ();
@@ -934,44 +1054,7 @@
           arguments[argument_vec[i].first] = loaded_arg;
         }
 
-      std::list<jit_block *>::const_iterator biter;
-      for (biter = blocks.begin (); biter != blocks.end (); ++biter)
-        {
-          jit_block *jblock = *biter;
-          llvm::BasicBlock *block = llvm::BasicBlock::Create (context,
-                                                              jblock->name (),
-                                                              function);
-          jblock->stash_llvm (block);
-        }
-
-      jit_block *first = *blocks.begin ();
-      builder.CreateBr (first->to_llvm ());
-
-      // constants aren't in the IR, we visit those first
-      for (std::list<jit_value *>::const_iterator iter = constants.begin ();
-           iter != constants.end (); ++iter)
-        if (! isa<jit_instruction> (*iter))
-          visit (*iter);
-
-      // convert all instructions
-      for (biter = blocks.begin (); biter != blocks.end (); ++biter)
-        visit (*biter);
-
-      // now finish phi nodes
-      for (biter = blocks.begin (); biter != blocks.end (); ++biter)
-        {
-          jit_block& block = **biter;
-          for (jit_block::iterator piter = block.begin ();
-               piter != block.end () && isa<jit_phi> (*piter); ++piter)
-            {
-              jit_instruction *phi = *piter;
-              finish_phi (static_cast<jit_phi *> (phi));
-            }
-        }
-
-      jit_block *last = blocks.back ();
-      builder.SetInsertPoint (last->to_llvm ());
-      builder.CreateRetVoid ();
+      convert (blocks, constants);
     } catch (const jit_fail_exception& e)
     {
       function->eraseFromParent ();
@@ -981,6 +1064,92 @@
   return function;
 }
 
+
+jit_function
+jit_convert_llvm::convert_function (llvm::Module *module,
+                                    const jit_block_list& blocks,
+                                    const std::list<jit_value *>& constants,
+                                    octave_user_function& fcn,
+                                    const std::vector<jit_type *>& args)
+{
+  converting_function = true;
+
+  jit_block *final_block = blocks.back ();
+  jit_return *ret = dynamic_cast<jit_return *> (final_block->back ());
+  assert (ret);
+
+  jit_function creating = jit_function (module, jit_convention::internal,
+                                        "foobar", ret->result_type (), args);
+  function = creating.to_llvm ();
+
+  try
+    {
+      prelude = creating.new_block ("prelude");
+      builder.SetInsertPoint (prelude);
+
+      tree_parameter_list *plist = fcn.parameter_list ();
+      if (plist)
+        {
+          tree_parameter_list::iterator piter = plist->begin ();
+          tree_parameter_list::iterator pend = plist->end ();
+          for (size_t i = 0; i < args.size () && piter != pend; ++i, ++piter)
+            {
+              tree_decl_elt *elt = *piter;
+              std::string arg_name = elt->name ();
+              arguments[arg_name] = creating.argument (builder, i);
+            }
+        }
+
+      convert (blocks, constants);
+    } catch (const jit_fail_exception& e)
+    {
+      function->eraseFromParent ();
+      throw;
+    }
+
+  return creating;
+}
+
+void
+jit_convert_llvm::convert (const jit_block_list& blocks,
+                           const std::list<jit_value *>& constants)
+{
+  std::list<jit_block *>::const_iterator biter;
+  for (biter = blocks.begin (); biter != blocks.end (); ++biter)
+    {
+      jit_block *jblock = *biter;
+      llvm::BasicBlock *block = llvm::BasicBlock::Create (context,
+                                                          jblock->name (),
+                                                          function);
+      jblock->stash_llvm (block);
+    }
+
+  jit_block *first = *blocks.begin ();
+  builder.CreateBr (first->to_llvm ());
+
+  // constants aren't in the IR, we visit those first
+  for (std::list<jit_value *>::const_iterator iter = constants.begin ();
+       iter != constants.end (); ++iter)
+    if (! isa<jit_instruction> (*iter))
+      visit (*iter);
+
+  // convert all instructions
+  for (biter = blocks.begin (); biter != blocks.end (); ++biter)
+    visit (*biter);
+
+  // now finish phi nodes
+  for (biter = blocks.begin (); biter != blocks.end (); ++biter)
+    {
+      jit_block& block = **biter;
+      for (jit_block::iterator piter = block.begin ();
+           piter != block.end () && isa<jit_phi> (*piter); ++piter)
+        {
+          jit_instruction *phi = *piter;
+          finish_phi (static_cast<jit_phi *> (phi));
+        }
+    }
+}
+
 void
 jit_convert_llvm::finish_phi (jit_phi *phi)
 {
@@ -1089,10 +1258,16 @@
 {
   llvm::Value *arg = arguments[extract.name ()];
   assert (arg);
-  arg = builder.CreateLoad (arg);
 
-  const jit_function& ol = extract.overload ();
-  extract.stash_llvm (ol.call (builder, arg));
+  if (converting_function)
+    extract.stash_llvm (arg);
+  else
+    {
+      arg = builder.CreateLoad (arg);
+
+      const jit_function& ol = extract.overload ();
+      extract.stash_llvm (ol.call (builder, arg));
+    }
 }
 
 void
@@ -1105,6 +1280,16 @@
 }
 
 void
+jit_convert_llvm::visit (jit_return& ret)
+{
+  jit_value *res = ret.result ();
+  if (res)
+    builder.CreateRet (res->to_llvm ());
+  else
+    builder.CreateRetVoid ();
+}
+
+void
 jit_convert_llvm::visit (jit_phi& phi)
 {
   // we might not have converted all incoming branches, so we don't
@@ -1539,44 +1724,27 @@
 bool
 tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds)
 {
-  const size_t MIN_TRIP_COUNT = 1000;
-
-  size_t tc = trip_count (bounds);
-  if (! tc || ! initialize ())
-    return false;
-
-  jit_info::vmap extra_vars;
-  extra_vars["#for_bounds0"] = &bounds;
-
-  jit_info *info = cmd.get_info ();
-  if (! info || ! info->match (extra_vars))
-    {
-      if (tc < MIN_TRIP_COUNT)
-        return false;
-
-      delete info;
-      info = new jit_info (*this, cmd, bounds);
-      cmd.stash_info (info);
-    }
-
-  return info->execute (extra_vars);
+  return instance ().do_execute (cmd, bounds);
 }
 
 bool
 tree_jit::execute (tree_while_command& cmd)
 {
-  if (! initialize ())
-    return false;
+  return instance ().do_execute (cmd);
+}
 
-  jit_info *info = cmd.get_info ();
-  if (! info || ! info->match ())
-    {
-      delete info;
-      info = new jit_info (*this, cmd);
-      cmd.stash_info (info);
-    }
+bool
+tree_jit::execute (octave_user_function& fcn, const octave_value_list& args,
+                   octave_value_list& retval)
+{
+  return instance ().do_execute (fcn, args, retval);
+}
 
-  return info->execute ();
+tree_jit&
+tree_jit::instance (void)
+{
+  static tree_jit ret;
+  return ret;
 }
 
 bool
@@ -1616,6 +1784,67 @@
   return true;
 }
 
+bool
+tree_jit::do_execute (tree_simple_for_command& cmd, const octave_value& bounds)
+{
+  const size_t MIN_TRIP_COUNT = 1000;
+
+  size_t tc = trip_count (bounds);
+  if (! tc || ! initialize ())
+    return false;
+
+  jit_info::vmap extra_vars;
+  extra_vars["#for_bounds0"] = &bounds;
+
+  jit_info *info = cmd.get_info ();
+  if (! info || ! info->match (extra_vars))
+    {
+      if (tc < MIN_TRIP_COUNT)
+        return false;
+
+      delete info;
+      info = new jit_info (*this, cmd, bounds);
+      cmd.stash_info (info);
+    }
+
+  return info->execute (extra_vars);
+}
+
+bool
+tree_jit::do_execute (tree_while_command& cmd)
+{
+  if (! initialize ())
+    return false;
+
+  jit_info *info = cmd.get_info ();
+  if (! info || ! info->match ())
+    {
+      delete info;
+      info = new jit_info (*this, cmd);
+      cmd.stash_info (info);
+    }
+
+  return info->execute ();
+}
+
+bool
+tree_jit::do_execute (octave_user_function& fcn, const octave_value_list& args,
+                      octave_value_list& retval)
+{
+  if (! initialize ())
+    return false;
+
+  jit_function_info *info = fcn.get_info ();
+    if (! info || ! info->match (args))
+      {
+        delete info;
+        info = new jit_function_info (*this, fcn, args);
+        fcn.stash_info (info);
+      }
+
+    return info->execute (args, retval);
+}
+
 size_t
 tree_jit::trip_count (const octave_value& bounds) const
 {
@@ -1644,6 +1873,163 @@
 #endif
 }
 
+// -------------------- jit_function_info --------------------
+jit_function_info::jit_function_info (tree_jit& tjit,
+                                      octave_user_function& fcn,
+                                      const octave_value_list& ov_args)
+  : argument_types (ov_args.length ()), function (0)
+{
+  size_t nargs = ov_args.length ();
+  for (size_t i = 0; i < nargs; ++i)
+    argument_types[i] = jit_typeinfo::type_of (ov_args(i));
+
+  try
+    {
+      jit_convert conv (fcn, argument_types);
+      jit_infer infer (conv.get_factory (), conv.get_blocks (),
+                       conv.get_variable_map ());
+      infer.infer ();
+
+#if OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          jit_block_list& blocks = infer.get_blocks ();
+          jit_block *entry_block = blocks.front ();
+          entry_block->label ();
+          std::cout << "-------------------- Compiling function ";
+          std::cout << "--------------------\n";
+
+          tree_print_code tpc (std::cout);
+          tpc.visit_octave_user_function_header (fcn);
+          tpc.visit_statement_list (*fcn.body ());
+          tpc.visit_octave_user_function_trailer (fcn);
+          blocks.print (std::cout, "octave jit ir");
+        }
+#endif
+
+      jit_factory& factory = conv.get_factory ();
+      llvm::Module *module = tjit.get_module ();
+      jit_convert_llvm to_llvm;
+      jit_function raw_fn = to_llvm.convert_function (module,
+                                                      infer.get_blocks (),
+                                                      factory.constants (),
+                                                      fcn, argument_types);
+
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          std::cout << "-------------------- raw function ";
+          std::cout << "--------------------\n";
+          std::cout << *raw_fn.to_llvm () << std::endl;
+        }
+#endif
+
+      std::string wrapper_name = fcn.name () + "_wrapper";
+      jit_type *any_t = jit_typeinfo::get_any ();
+      std::vector<jit_type *> wrapper_args (1, jit_typeinfo::get_any_ptr ());
+      jit_function wrapper (module, jit_convention::internal, wrapper_name,
+                            any_t, wrapper_args);
+      llvm::BasicBlock *wrapper_body = wrapper.new_block ();
+      builder.SetInsertPoint (wrapper_body);
+
+      llvm::Value *wrapper_arg = wrapper.argument (builder, 0);
+      std::vector<llvm::Value *> raw_args (nargs);
+      for (size_t i = 0; i < nargs; ++i)
+        {
+          llvm::Value *arg;
+          arg = builder.CreateConstInBoundsGEP1_32 (wrapper_arg, i);
+          arg = builder.CreateLoad (arg);
+
+          jit_type *arg_type = argument_types[i];
+          const jit_function& cast = jit_typeinfo::cast (arg_type, any_t);
+          raw_args[i] = cast.call (builder, arg);
+        }
+
+      llvm::Value *result = raw_fn.call (builder, raw_args);
+      if (raw_fn.result ())
+        {
+          jit_type *raw_result_t = raw_fn.result ();
+          const jit_function& cast = jit_typeinfo::cast (any_t, raw_result_t);
+          result = cast.call (builder, result);
+        }
+      else
+        {
+          llvm::Value *zero = builder.getInt32 (0);
+          result = builder.CreateBitCast (zero, any_t->to_llvm ());
+        }
+
+      wrapper.do_return (builder, result);
+
+      llvm::Function *llvm_function = wrapper.to_llvm ();
+      tjit.optimize (llvm_function);
+
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          std::cout << "-------------------- optimized and wrapped ";
+          std::cout << "--------------------\n";
+          std::cout << *llvm_function << std::endl;
+        }
+#endif
+
+      llvm::ExecutionEngine* engine = tjit.get_engine ();
+      void *void_fn = engine->getPointerToFunction (llvm_function);
+      function = reinterpret_cast<jited_function> (void_fn);
+    }
+  catch (const jit_fail_exception& e)
+    {
+      argument_types.clear ();
+#ifdef OCTAVE_JIT_DEBUG
+      if (Venable_jit_debug)
+        {
+          if (e.known ())
+            std::cout << "jit fail: " << e.what () << std::endl;
+        }
+#endif
+    }
+}
+
+bool
+jit_function_info::execute (const octave_value_list& ov_args,
+                            octave_value_list& retval) const
+{
+  if (! function)
+    return false;
+
+  // TODO figure out a way to delete ov_args so we avoid duplicating refcount
+  size_t nargs = ov_args.length ();
+  std::vector<octave_base_value *> args (nargs);
+  for (size_t i = 0; i < nargs; ++i)
+    {
+      octave_base_value *obv = ov_args(i).internal_rep ();
+      obv->grab ();
+      args[i] = obv;
+    }
+
+  octave_base_value *ret = function (&args[0]);
+  if (ret)
+    retval(0) = octave_value (ret);
+
+  return true;
+}
+
+bool
+jit_function_info::match (const octave_value_list& ov_args) const
+{
+  if (! function)
+    return true;
+
+  size_t nargs = ov_args.length ();
+  if (nargs != argument_types.size ())
+    return false;
+
+  for (size_t i = 0; i < nargs; ++i)
+    if (jit_typeinfo::type_of (ov_args(i)) != argument_types[i])
+      return false;
+
+  return true;
+}
+
 // -------------------- jit_info --------------------
 jit_info::jit_info (tree_jit& tjit, tree& tee)
   : engine (tjit.get_engine ()), function (0), llvm_function (0)
@@ -1739,8 +2125,9 @@
 
       jit_factory& factory = conv.get_factory ();
       jit_convert_llvm to_llvm;
-      llvm_function = to_llvm.convert (tjit.get_module (), infer.get_blocks (),
-                                       factory.constants ());
+      llvm_function = to_llvm.convert_loop (tjit.get_module (),
+                                            infer.get_blocks (),
+                                            factory.constants ());
       arguments = to_llvm.get_arguments ();
       bounds = conv.get_bounds ();
     }
@@ -2126,4 +2513,13 @@
 
 %!error <undefined near> (test_undef);
 
+%!shared id
+%! id = @(x) x;
+
+%!assert (id (1), 1);
+%!assert (id (1+1i), 1+1i)
+%!assert (id (1, 2), 1)
+%!error <undefined> (id ())
+
+
 */
--- a/libinterp/interp-core/pt-jit.h	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/interp-core/pt-jit.h	Sun Sep 09 00:29:00 2012 -0600
@@ -26,8 +26,10 @@
 #ifdef HAVE_LLVM
 
 #include "jit-ir.h"
+#include "pt-walk.h"
+#include "symtab.h"
 
-#include "pt-walk.h"
+class octave_value_list;
 
 // Convert from the parse tree (AST) to the low level Octave IR.
 class
@@ -40,6 +42,8 @@
 
   jit_convert (tree &tee, jit_type *for_bounds = 0);
 
+  jit_convert (octave_user_function& fcn, const std::vector<jit_type *>& args);
+
 #define DECL_ARG(n) const ARG ## n& arg ## n
 #define JIT_CREATE_CHECKED(N)                                           \
   template <OCT_MAKE_DECL_LIST (typename, ARG, N)>                      \
@@ -156,6 +160,11 @@
   std::vector<std::pair<std::string, bool> > arguments;
   type_bound_vector bounds;
 
+  bool converting_function;
+
+  // the scope of the function we are converting, or the current scope
+  symbol_table::scope_id scope;
+
   jit_factory factory;
 
   // used instead of return values from visit_* functions
@@ -179,6 +188,8 @@
 
   variable_map vmap;
 
+  void initialize (symbol_table::scope_id s);
+
   jit_call *create_checked_impl (jit_call *ret);
 
   // get an existing vairable. If the variable does not exist, it will not be
@@ -191,7 +202,8 @@
 
   // create a variable of the given name and given type. Will also insert an
   // extract statement
-  jit_variable *create_variable (const std::string& vname, jit_type *type);
+  jit_variable *create_variable (const std::string& vname, jit_type *type,
+                                 bool isarg = true);
 
   // The name of the next for loop iterator. If inc is false, then the iterator
   // counter will not be incremented.
@@ -233,10 +245,17 @@
 jit_convert_llvm : public jit_ir_walker
 {
 public:
-  llvm::Function *convert (llvm::Module *module,
-                           const jit_block_list& blocks,
-                           const std::list<jit_value *>& constants);
+  llvm::Function *convert_loop (llvm::Module *module,
+                                const jit_block_list& blocks,
+                                const std::list<jit_value *>& constants);
 
+  jit_function convert_function (llvm::Module *module,
+                                 const jit_block_list& blocks,
+                                 const std::list<jit_value *>& constants,
+                                 octave_user_function& fcn,
+                                 const std::vector<jit_type *>& args);
+
+  // arguments to the llvm::Function for loops
   const std::vector<std::pair<std::string, bool> >& get_arguments(void) const
   { return argument_vec; }
 
@@ -247,13 +266,22 @@
 
 #undef JIT_METH
 private:
+  // name -> argument index (used for compiling functions)
+  std::map<std::string, int> argument_index;
+
   std::vector<std::pair<std::string, bool> > argument_vec;
 
-  // name -> llvm argument
+  // name -> llvm argument (used for compiling loops)
   std::map<std::string, llvm::Value *> arguments;
+
+  bool converting_function;
+
   llvm::Function *function;
   llvm::BasicBlock *prelude;
 
+  void convert (const jit_block_list& blocks,
+                const std::list<jit_value *>& constants);
+
   void finish_phi (jit_phi *phi);
 
   void visit (jit_value *jvalue)
@@ -319,13 +347,15 @@
 tree_jit
 {
 public:
-  tree_jit (void);
-
   ~tree_jit (void);
 
-  bool execute (tree_simple_for_command& cmd, const octave_value& bounds);
+  static bool execute (tree_simple_for_command& cmd,
+                       const octave_value& bounds);
 
-  bool execute (tree_while_command& cmd);
+  static bool execute (tree_while_command& cmd);
+
+  static bool execute (octave_user_function& fcn, const octave_value_list& args,
+                       octave_value_list& retval);
 
   llvm::ExecutionEngine *get_engine (void) const { return engine; }
 
@@ -333,8 +363,19 @@
 
   void optimize (llvm::Function *fn);
  private:
+  tree_jit (void);
+
+  static tree_jit& instance (void);
+
   bool initialize (void);
 
+  bool do_execute (tree_simple_for_command& cmd, const octave_value& bounds);
+
+  bool do_execute (tree_while_command& cmd);
+
+  bool do_execute (octave_user_function& fcn, const octave_value_list& args,
+                   octave_value_list& retval);
+
   size_t trip_count (const octave_value& bounds) const;
 
   llvm::Module *module;
@@ -344,6 +385,24 @@
 };
 
 class
+jit_function_info
+{
+public:
+  jit_function_info (tree_jit& tjit, octave_user_function& fcn,
+                     const octave_value_list& ov_args);
+
+  bool execute (const octave_value_list& ov_args,
+                octave_value_list& retval) const;
+
+  bool match (const octave_value_list& ov_args) const;
+private:
+  typedef octave_base_value *(*jited_function)(octave_base_value**);
+
+  std::vector<jit_type *> argument_types;
+  jited_function function;
+};
+
+class
 jit_info
 {
 public:
--- a/libinterp/octave-value/ov-usr-fcn.cc	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/octave-value/ov-usr-fcn.cc	Sun Sep 09 00:29:00 2012 -0600
@@ -39,6 +39,7 @@
 #include "ov.h"
 #include "pager.h"
 #include "pt-eval.h"
+#include "pt-jit.h"
 #include "pt-jump.h"
 #include "pt-misc.h"
 #include "pt-pr-code.h"
@@ -192,6 +193,9 @@
     class_constructor (false), class_method (false),
     parent_scope (-1), local_scope (sid),
     curr_unwind_protect_frame (0)
+#ifdef HAVE_LLVM
+    , jit_info (0)
+#endif
 {
   if (cmd_list)
     cmd_list->mark_as_function_body ();
@@ -208,6 +212,10 @@
   delete lead_comm;
   delete trail_comm;
 
+#ifdef HAVE_LLVM
+  delete jit_info;
+#endif
+
   symbol_table::erase_scope (local_scope);
 }
 
@@ -372,6 +380,12 @@
   if (! cmd_list)
     return retval;
 
+#ifdef HAVE_LLVM
+  if (Venable_jit_compiler && is_special_expr ()
+      && tree_jit::execute (*this, args, retval))
+    return retval;
+#endif
+
   int nargin = args.length ();
 
   unwind_protect frame;
@@ -457,23 +471,14 @@
   frame.protect_var (tree_evaluator::statement_context);
   tree_evaluator::statement_context = tree_evaluator::function;
 
-  bool special_expr = (is_inline_function () || is_anonymous_function ());
-
   BEGIN_PROFILER_BLOCK (profiler_name ())
 
-  if (special_expr)
+  if (is_special_expr ())
     {
-      assert (cmd_list->length () == 1);
-
-      tree_statement *stmt = 0;
+      tree_expression *expr = special_expr ();
 
-      if ((stmt = cmd_list->front ())
-          && stmt->is_expression ())
-        {
-          tree_expression *expr = stmt->expression ();
-
-          retval = expr->rvalue (nargout);
-        }
+      if (expr)
+        retval = expr->rvalue (nargout);
     }
   else
     cmd_list->accept (*current_evaluator);
@@ -497,7 +502,7 @@
 
   // Copy return values out.
 
-  if (ret_list && ! special_expr)
+  if (ret_list && ! is_special_expr ())
     {
       ret_list->initialize_undefined_elements (my_name, nargout, Matrix ());
 
@@ -529,6 +534,16 @@
   tw.visit_octave_user_function (*this);
 }
 
+tree_expression *
+octave_user_function::special_expr (void)
+{
+  assert (is_special_expr ());
+  assert (cmd_list->length () == 1);
+
+  tree_statement *stmt = cmd_list->front ();
+  return stmt->expression ();
+}
+
 bool
 octave_user_function::subsasgn_optimization_ok (void)
 {
--- a/libinterp/octave-value/ov-usr-fcn.h	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/octave-value/ov-usr-fcn.h	Sun Sep 09 00:29:00 2012 -0600
@@ -41,8 +41,13 @@
 class tree_parameter_list;
 class tree_statement_list;
 class tree_va_return_list;
+class tree_expression;
 class tree_walker;
 
+#ifdef HAVE_LLVM
+class jit_function_info;
+#endif
+
 class
 octave_user_code : public octave_function
 {
@@ -283,6 +288,14 @@
       : false;
   }
 
+  // If we are a special expression, then the function body consists of exactly
+  // one expression. The expression's result is the return value of the
+  // function.
+  bool is_special_expr (void) const
+  {
+    return is_inline_function () || is_anonymous_function ();
+  }
+
   bool is_nested_function (void) const { return nested_function; }
 
   void mark_as_nested_function (void) { nested_function = true; }
@@ -335,6 +348,10 @@
 
   octave_comment_list *trailing_comment (void) { return trail_comm; }
 
+  // If is_special_expr is true, retrieve the sigular expression that forms the
+  // body. May be null (even if is_special_expr is true).
+  tree_expression *special_expr (void);
+
   bool subsasgn_optimization_ok (void);
 
   void accept (tree_walker& tw);
@@ -351,6 +368,12 @@
         return false;
     }
 
+#ifdef HAVE_LLVM
+  jit_function_info *get_info (void) { return jit_info; }
+
+  void stash_info (jit_function_info *info) { jit_info = info; }
+#endif
+
 #if 0
   void print_symtab_info (std::ostream& os) const;
 #endif
@@ -427,6 +450,10 @@
   // pointer to the current unwind_protect frame of this function.
   unwind_protect *curr_unwind_protect_frame;
 
+#ifdef HAVE_LLVM
+  jit_function_info *jit_info;
+#endif
+
 #if 0
   // The symbol record for argn in the local symbol table.
   octave_value& argn_varref;
--- a/libinterp/parse-tree/pt-decl.h	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/parse-tree/pt-decl.h	Sun Sep 09 00:29:00 2012 -0600
@@ -84,6 +84,8 @@
 
   tree_identifier *ident (void) { return id; }
 
+  std::string name (void) { return id ? id->name () : ""; }
+
   tree_expression *expression (void) { return expr; }
 
   tree_decl_elt *dup (symbol_table::scope_id scope,
--- a/libinterp/parse-tree/pt-eval.cc	Sat Sep 08 18:47:29 2012 -0700
+++ b/libinterp/parse-tree/pt-eval.cc	Sun Sep 09 00:29:00 2012 -0600
@@ -47,10 +47,6 @@
 //FIXME: This should be part of tree_evaluator
 #include "pt-jit.h"
 
-#if HAVE_LLVM
-static tree_jit jiter;
-#endif
-
 static tree_evaluator std_evaluator;
 
 tree_evaluator *current_evaluator = &std_evaluator;
@@ -311,7 +307,7 @@
   octave_value rhs = expr->rvalue1 ();
 
 #if HAVE_LLVM
-  if (Venable_jit_compiler && jiter.execute (cmd, rhs))
+  if (Venable_jit_compiler && tree_jit::execute (cmd, rhs))
     return;
 #endif
 
@@ -1048,7 +1044,7 @@
     return;
 
 #if HAVE_LLVM
-  if (Venable_jit_compiler && jiter.execute (cmd))
+  if (Venable_jit_compiler && tree_jit::execute (cmd))
     return;
 #endif