changeset 14915:cba58541954c

Add if support and fix leak with any
author Max Brister <max@2bass.com>
date Mon, 21 May 2012 15:41:19 -0600
parents 1e5eafcb83f8
children 0b0569667939
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 381 insertions(+), 148 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc	Fri May 18 10:24:14 2012 -0600
+++ b/src/pt-jit.cc	Mon May 21 15:41:19 2012 -0600
@@ -88,8 +88,8 @@
 octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs,
                            octave_base_value *rhs)
 {
-  octave_value olhs (lhs, true);
-  octave_value orhs (rhs, true);
+  octave_value olhs (lhs);
+  octave_value orhs (rhs);
   octave_value result = do_binary_op (op, olhs, orhs);
   octave_base_value *rep = result.internal_rep ();
   rep->grab ();
@@ -97,13 +97,15 @@
 }
 
 extern "C" void
-octave_jit_assign_any_any_help (octave_base_value *lhs, octave_base_value *rhs)
+octave_jit_release_any (octave_base_value *obv)
 {
-  if (lhs != rhs)
-    {
-      rhs->grab ();
-      lhs->release ();
-    }
+  obv->release ();
+}
+
+extern "C" void
+octave_jit_grab_any (octave_base_value *obv)
+{
+  obv->grab ();
 }
 
 // -------------------- jit_type --------------------
@@ -222,33 +224,21 @@
   index = new_type ("index", false, any, index_t);
 
   // any with anything is an any op
+  llvm::Function *fn;
   llvm::Type *binary_op_type
     = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op));
-  std::vector<llvm::Type*> args (3);
-  args[0] = binary_op_type;
-  args[1] = args[2] = any->to_llvm ();
-  llvm::FunctionType *any_binary_t = llvm::FunctionType::get (ov_t, args, false);
-  llvm::Function *any_binary = llvm::Function::Create (any_binary_t,
-                                                       llvm::Function::ExternalLinkage,
-                                                       "octave_jit_binary_any_any",
-                                                       module);
+  llvm::Function *any_binary = create_function ("octave_jit_binary_any_any",
+                                                any->to_llvm (), binary_op_type,
+                                                any->to_llvm (), any->to_llvm ());
   engine->addGlobalMapping (any_binary,
                             reinterpret_cast<void*>(&octave_jit_binary_any_any));
 
-  args.resize (2);
-  args[0] = any->to_llvm ();
-  args[1] = any->to_llvm ();
-
   binary_ops.resize (octave_value::num_binary_ops);
   for (int op = 0; op < octave_value::num_binary_ops; ++op)
     {
-      llvm::FunctionType *ftype = llvm::FunctionType::get (ov_t, args, false);
-
       llvm::Twine fn_name ("octave_jit_binary_any_any_");
       fn_name = fn_name + llvm::Twine (op);
-      llvm::Function *fn = llvm::Function::Create (ftype,
-                                                   llvm::Function::ExternalLinkage,
-                                                   fn_name, module);
+      fn = create_function (fn_name, any, any, any);
       llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn);
       builder.SetInsertPoint (block);
       llvm::APInt op_int(sizeof (octave_value::binary_op), op,
@@ -265,49 +255,18 @@
         binary_ops[op].add_overload (overload);
     }
 
-  // assign any = any
   llvm::Type *void_t = llvm::Type::getVoidTy (ctx);
-  args.resize (2);
-  args[0] = any->to_llvm ();
-  args[1] = any->to_llvm ();
-  llvm::FunctionType *ft = llvm::FunctionType::get (void_t, args, false);
-  llvm::Function *fn_help = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                                                    "octave_jit_assign_any_any_help",
-                                                    module);
-  engine->addGlobalMapping (fn_help,
-                            reinterpret_cast<void*>(&octave_jit_assign_any_any_help));
 
-  args.resize (2);
-  args[0] = any->to_llvm_arg ();
-  args[1] = any->to_llvm ();
-  ft = llvm::FunctionType::get (void_t, args, false);
-  llvm::Function *fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                                               "octave_jit_assign_any_any",
-                                               module);
-  fn->addFnAttr (llvm::Attribute::AlwaysInline);
-  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn);
-  builder.SetInsertPoint (body);
-  llvm::Value *value = builder.CreateLoad (fn->arg_begin ());
-  builder.CreateCall2 (fn_help, value, ++fn->arg_begin ());
-  builder.CreateStore (++fn->arg_begin (), fn->arg_begin ());
-  builder.CreateRetVoid ();
-  llvm::verifyFunction (*fn);
-  assign_fn.add_overload (fn, false, 0, any, any);
+  // grab any
+  fn = create_function ("octave_jit_grab_any", void_t, any->to_llvm ());
+                        
+  engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_grab_any));
+  grab_fn.add_overload (fn, false, 0, any);
 
-  // assign scalar = scalar
-  args.resize (2);
-  args[0] = scalar->to_llvm_arg ();
-  args[1] = scalar->to_llvm ();
-  ft = llvm::FunctionType::get (void_t, args, false);
-  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                               "octave_jit_assign_scalar_scalar", module);
-  fn->addFnAttr (llvm::Attribute::AlwaysInline);
-  body = llvm::BasicBlock::Create (ctx, "body", fn);
-  builder.SetInsertPoint (body);
-  builder.CreateStore (++fn->arg_begin (), fn->arg_begin ());
-  builder.CreateRetVoid ();
-  llvm::verifyFunction (*fn);
-  assign_fn.add_overload (fn, false, 0, scalar, scalar);
+  // release any
+  fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ());
+  engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any));
+  release_fn.add_overload (fn, false, 0, any);
 
   // now for binary scalar operations
   // FIXME: Finish all operations
@@ -320,19 +279,20 @@
   add_binary_op (scalar, octave_value::op_div, llvm::Instruction::FDiv);
   add_binary_op (scalar, octave_value::op_el_div, llvm::Instruction::FDiv);
 
+  add_binary_fcmp (scalar, octave_value::op_lt, llvm::CmpInst::FCMP_ULT);
+  add_binary_fcmp (scalar, octave_value::op_le, llvm::CmpInst::FCMP_ULE);
+  add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ);
+  add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE);
+  add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT);
+  add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE);
+
   // now for printing functions
   add_print (any, reinterpret_cast<void*> (&octave_jit_print_any));
   add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double));
 
   // bounds check for for loop
-  args.resize (2);
-  args[0] = range->to_llvm ();
-  args[1] = index->to_llvm ();
-  ft = llvm::FunctionType::get (bool_t, args, false);
-  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                               "octave_jit_simple_for_range", module);
-  fn->addFnAttr (llvm::Attribute::AlwaysInline);
-  body = llvm::BasicBlock::Create (ctx, "body", fn);
+  fn = create_function ("octave_jit_simple_for_range", boolean, range, index);
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn);
   builder.SetInsertPoint (body);
   {
     llvm::Value *nelem
@@ -346,12 +306,7 @@
   simple_for_check.add_overload (fn, false, boolean, range, index);
 
   // increment for for loop
-  args.resize (1);
-  args[0] = index->to_llvm ();
-  ft = llvm::FunctionType::get (index->to_llvm (), args, false);
-  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                               "octave_jit_imple_for_range_incr", module);
-  fn->addFnAttr (llvm::Attribute::AlwaysInline);
+  fn = create_function ("octave_jit_imple_for_range_incr", index, index);
   body = llvm::BasicBlock::Create (ctx, "body", fn);
   builder.SetInsertPoint (body);
   {
@@ -364,13 +319,7 @@
   simple_for_incr.add_overload (fn, false, index, index);
 
   // index variabe for for loop
-  args.resize (2);
-  args[0] = range->to_llvm ();
-  args[1] = index->to_llvm ();
-  ft = llvm::FunctionType::get (dbl, args, false);
-  fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                               "octave_jit_simple_for_idx", module);
-  fn->addFnAttr (llvm::Attribute::AlwaysInline);
+  fn = create_function ("octave_jit_simple_for_idx", scalar, range, index);
   body = llvm::BasicBlock::Create (ctx, "body", fn);
   builder.SetInsertPoint (body);
   {
@@ -386,61 +335,127 @@
   }
   llvm::verifyFunction (*fn);
   simple_for_index.add_overload (fn, false, scalar, range, index);
+
+  // logically true
+  // FIXME: Check for NaN
+  fn = create_function ("octave_logically_true_scalar", boolean, scalar);
+  body = llvm::BasicBlock::Create (ctx, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0);
+    llvm::Value *ret = builder.CreateFCmpUNE (fn->arg_begin (), zero);
+    builder.CreateRet (ret);
+  }
+  llvm::verifyFunction (*fn);
+  logically_true.add_overload (fn, true, boolean, scalar);
+
+  fn = create_function ("octave_logically_true_bool", boolean, boolean);
+  body = llvm::BasicBlock::Create (ctx, "body", fn);
+  builder.SetInsertPoint (body);
+  builder.CreateRet (fn->arg_begin ());
+  llvm::verifyFunction (*fn);
+  logically_true.add_overload (fn, false, boolean, boolean);
 }
 
 void
 jit_typeinfo::add_print (jit_type *ty, void *call)
 {
-  llvm::LLVMContext& ctx = llvm::getGlobalContext ();
-  llvm::Type *void_t = llvm::Type::getVoidTy (ctx);
-  std::vector<llvm::Type *> args (2);
-  args[0] = llvm::Type::getInt8PtrTy (ctx);
-  args[1] = ty->to_llvm ();
-
   std::stringstream name;
   name << "octave_jit_print_" << ty->name ();
 
-  llvm::FunctionType *print_ty = llvm::FunctionType::get (void_t, args, false);
-  llvm::Function *fn = llvm::Function::Create (print_ty,
-                                               llvm::Function::ExternalLinkage,
-                                               name.str (), module);
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::Type *void_t = llvm::Type::getVoidTy (ctx);
+  llvm::Function *fn = create_function (name.str (), void_t,
+                                        llvm::Type::getInt8PtrTy (ctx),
+                                        ty->to_llvm ());
   engine->addGlobalMapping (fn, call);
 
   jit_function::overload ol (fn, false, 0, ty);
   print_fn.add_overload (ol);
 }
 
+// FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp
 void
 jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op)
 {
-  llvm::LLVMContext& ctx = llvm::getGlobalContext ();
-  std::vector<llvm::Type *> args (2, ty->to_llvm ());
-  llvm::FunctionType *ft = llvm::FunctionType::get (ty->to_llvm (), args,
-                                                    false);
-
   std::stringstream fname;
   octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op);
   fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op)
         << "_" << ty->name ();
 
-  llvm::Function *fn = llvm::Function::Create (ft,
-                                               llvm::Function::ExternalLinkage,
-                                               fname.str (),
-                                               module);
-  fn->addFnAttr (llvm::Attribute::AlwaysInline);
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::Function *fn = create_function (fname.str (), ty, ty, ty);
   llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn);
-  llvm::IRBuilder<> fn_builder (block);
+  builder.SetInsertPoint (block);
   llvm::Instruction::BinaryOps temp
     = static_cast<llvm::Instruction::BinaryOps>(llvm_op);
-  llvm::Value *ret = fn_builder.CreateBinOp (temp, fn->arg_begin (),
-                                             ++fn->arg_begin ());
-  fn_builder.CreateRet (ret);
+  llvm::Value *ret = builder.CreateBinOp (temp, fn->arg_begin (),
+                                          ++fn->arg_begin ());
+  builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
   jit_function::overload ol(fn, false, ty, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
+void
+jit_typeinfo::add_binary_icmp (jit_type *ty, int op, int llvm_op)
+{
+  std::stringstream fname;
+  octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op);
+  fname << "octave_jit" << octave_value::binary_op_as_string (ov_op)
+        << "_" << ty->name ();
+
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::Function *fn = create_function (fname.str (), boolean, ty, ty);
+  llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn);
+  builder.SetInsertPoint (block);
+  llvm::CmpInst::Predicate temp
+    = static_cast<llvm::CmpInst::Predicate>(llvm_op);
+  llvm::Value *ret = builder.CreateICmp (temp, fn->arg_begin (),
+                                         ++fn->arg_begin ());
+  builder.CreateRet (ret);
+  llvm::verifyFunction (*fn);
+
+  jit_function::overload ol (fn, false, boolean, ty, ty);
+  binary_ops[op].add_overload (ol);
+}
+
+void
+jit_typeinfo::add_binary_fcmp (jit_type *ty, int op, int llvm_op)
+{
+  std::stringstream fname;
+  octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op);
+  fname << "octave_jit" << octave_value::binary_op_as_string (ov_op)
+        << "_" << ty->name ();
+
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::Function *fn = create_function (fname.str (), boolean, ty, ty);
+  llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn);
+  builder.SetInsertPoint (block);
+  llvm::CmpInst::Predicate temp
+    = static_cast<llvm::CmpInst::Predicate>(llvm_op);
+  llvm::Value *ret = builder.CreateFCmp (temp, fn->arg_begin (),
+                                         ++fn->arg_begin ());
+  builder.CreateRet (ret);
+  llvm::verifyFunction (*fn);
+
+  jit_function::overload ol (fn, false, boolean, ty, ty);
+  binary_ops[op].add_overload (ol);
+}
+
+llvm::Function *
+jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret,
+                               const std::vector<llvm::Type *>& args)
+{
+  llvm::FunctionType *ft = llvm::FunctionType::get (ret, args, false);
+  llvm::Function *fn = llvm::Function::Create (ft,
+                                               llvm::Function::ExternalLinkage,
+                                               name, module);
+  fn->addFnAttr (llvm::Attribute::AlwaysInline);
+  return fn;
+}    
+
 jit_type*
 jit_typeinfo::type_of (const octave_value &ov) const
 {
@@ -459,17 +474,11 @@
 const jit_function&
 jit_typeinfo::binary_op (int op) const
 {
+  assert (static_cast<size_t>(op) < binary_ops.size ());
   return binary_ops[op];
 }
 
 const jit_function::overload&
-jit_typeinfo::assign_op (jit_type *lhs, jit_type *rhs) const
-{
-  assert (lhs == rhs);
-  return assign_fn.get_overload (lhs, rhs);
-}
-
-const jit_function::overload&
 jit_typeinfo::print_value (jit_type *to_print) const
 {
   return print_fn.get_overload (to_print);
@@ -577,6 +586,9 @@
   if (is_lvalue)
     fail ();
 
+  if (be.op_type () >= octave_value::num_binary_ops)
+    fail ();
+
   tree_expression *lhs = be.lhs ();
   lhs->accept (*this);
   jit_type *tlhs = type_stack.back ();
@@ -644,7 +656,30 @@
   jit_type *control_t = type_stack.back ();
   type_stack.pop_back ();
 
-  infer_simple_for (cmd, control_t);
+  // FIXME: We should improve type inference so we don't have to do this
+  // to generate nested for loop code
+
+  // quick hack, check if the for loop bounds are const. If we
+  // run at least one, we don't have to merge types
+  bool atleast_once = false;
+  if (control->is_constant ())
+    {
+      octave_value over = control->rvalue1 ();
+      if (over.is_range ())
+        {
+          Range rng = over.range_value ();
+          atleast_once = rng.nelem () > 0;
+        }
+    }
+
+  if (atleast_once)
+    infer_simple_for (cmd, control_t);
+  else
+    {
+      type_map fallthrough = types;
+      infer_simple_for (cmd, control_t);
+      merge (types, fallthrough);
+    }
 }
 
 void
@@ -697,15 +732,52 @@
 }
 
 void
-jit_infer::visit_if_command (tree_if_command&)
+jit_infer::visit_if_command (tree_if_command& cmd)
 {
-  fail ();
+  if (is_lvalue)
+    fail ();
+
+  tree_if_command_list *lst = cmd.cmd_list ();
+  assert (lst);
+  lst->accept (*this);
 }
 
 void
-jit_infer::visit_if_command_list (tree_if_command_list&)
+jit_infer::visit_if_command_list (tree_if_command_list& lst)
 {
-  fail ();
+  // determine the types on each branch of the if seperatly, then merge
+  type_map fallthrough = types, last;
+  bool first_time = true;
+  for (tree_if_command_list::iterator p = lst.begin (); p != lst.end(); ++p)
+    {
+      tree_if_clause *tic = *p;
+
+      if (! first_time)
+        types = fallthrough;
+
+      if (! tic->is_else_clause ())
+        {
+          tree_expression *expr = tic->condition ();
+          expr->accept (*this);
+        }
+
+      fallthrough = types;
+
+      tree_statement_list *stmt_lst = tic->commands ();
+      assert (stmt_lst);
+      stmt_lst->accept (*this);
+
+      if (first_time)
+        last = types;
+      else
+        merge (last, types);
+    }
+
+  types = last;
+
+  tree_if_clause *last_clause = lst.back ();
+  if (! last_clause->is_else_clause ())
+    merge (types, fallthrough);
 }
 
 void
@@ -971,6 +1043,27 @@
     type_stack.push_back (iter->second.second);
 }
 
+void
+jit_infer::merge (type_map& dest, const type_map& src)
+{
+  if (dest.size () != src.size ())
+    fail ();
+
+  type_map::iterator dest_iter;
+  type_map::const_iterator src_iter;
+  for (dest_iter = dest.begin (), src_iter = src.begin ();
+       dest_iter != dest.end (); ++dest_iter, ++src_iter)
+    {
+      if (dest_iter->first.name () != src_iter->first.name ()
+          || dest_iter->second.second != src_iter->second.second)
+        fail ();
+
+      // require argin if one path requires argin
+      dest_iter->second.first = dest_iter->second.first
+        || src_iter->second.first;
+    }
+}
+
 // -------------------- jit_generator --------------------
 jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod,
                               tree_simple_for_command& cmd, jit_type *bounds,
@@ -1148,11 +1241,24 @@
   std::string name = ti.name ();
   value variable = variables[name];
   if (is_lvalue)
-    value_stack.push_back (variable);
+    {
+      value_stack.push_back (variable);
+
+      const jit_function::overload& ol = tinfo->release (variable.first);
+      if (ol.function)
+        {
+          llvm::Value *load = builder.CreateLoad (variable.second, name);
+          builder.CreateCall (ol.function, load);
+        }
+    }
   else
     {
       llvm::Value *load = builder.CreateLoad (variable.second, name);
       push_value (variable.first, load);
+
+      const jit_function::overload& ol = tinfo->grab (variable.first);
+      if (ol.function)
+        builder.CreateCall (ol.function, load);
     }
 }
 
@@ -1163,15 +1269,71 @@
 }
 
 void
-jit_generator::visit_if_command (tree_if_command&)
+jit_generator::visit_if_command (tree_if_command& cmd)
 {
-  fail ();
+  tree_if_command_list *lst = cmd.cmd_list ();
+  assert (lst);
+  lst->accept (*this);
 }
 
 void
-jit_generator::visit_if_command_list (tree_if_command_list&)
+jit_generator::visit_if_command_list (tree_if_command_list& lst)
 {
-  fail ();
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "if_tail", function);
+  std::vector<llvm::BasicBlock *> clause_entry (lst.size ());
+  tree_if_command_list::iterator p;
+  size_t i;
+  for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i)
+    {
+      tree_if_clause *tic = *p;
+      if (tic->is_else_clause ())
+        clause_entry[i] = llvm::BasicBlock::Create (ctx, "else_body", function,
+                                                    tail);
+      else
+        clause_entry[i] = llvm::BasicBlock::Create (ctx, "if_cond", function,
+                                                    tail);
+    }
+
+  builder.CreateBr (clause_entry[0]);
+
+  for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i)
+    {
+      tree_if_clause *tic = *p;
+      llvm::BasicBlock *body;
+      if (tic->is_else_clause ())
+        body = clause_entry[i];
+      else
+        {
+          llvm::BasicBlock *cond = clause_entry[i];
+          builder.SetInsertPoint (cond);
+
+          tree_expression *expr = tic->condition ();
+          expr->accept (*this);
+
+          // FIXME: Handle undefined case
+          value condv = value_stack.back ();
+          value_stack.pop_back ();
+
+          const jit_function::overload& ol = tinfo->get_logically_true (condv.first);
+          if (! ol.function)
+            fail ();
+
+          bool last = i + 1 == clause_entry.size ();
+          llvm::BasicBlock *next = last ? tail : clause_entry[i + 1];
+          body = llvm::BasicBlock::Create (ctx, "if_body", function, tail);
+
+          llvm::Value *is_true = builder.CreateCall (ol.function, condv.second);
+          builder.CreateCondBr (is_true, body, next);
+        }
+
+      tree_statement_list *stmt_lst = tic->commands ();
+      builder.SetInsertPoint (body);
+      stmt_lst->accept (*this);
+      builder.CreateBr (tail);
+    }
+
+  builder.SetInsertPoint (tail);
 }
 
 void
@@ -1278,25 +1440,24 @@
   if (is_lvalue)
     fail ();
 
-  // resolve lhs
-  is_lvalue = true;
-  tree_expression *lhs = tsa.left_hand_side ();
-  lhs->accept (*this);
-
-  value lhsv = value_stack.back ();
-  value_stack.pop_back ();
-
   // resolve rhs
-  is_lvalue = false;
   tree_expression *rhs = tsa.right_hand_side ();
   rhs->accept (*this);
 
   value rhsv = value_stack.back ();
   value_stack.pop_back ();
 
-  // do assign, then store rhs as the result
-  jit_function::overload ol = tinfo->assign_op (lhsv.first, rhsv.first);
-  builder.CreateCall2 (ol.function, lhsv.second, rhsv.second);
+  // resolve lhs
+  is_lvalue = true;
+  tree_expression *lhs = tsa.left_hand_side ();
+  lhs->accept (*this);
+  is_lvalue = false;
+
+  value lhsv = value_stack.back ();
+  value_stack.pop_back ();
+
+  // do assign, then keep rhs as the result
+  builder.CreateStore (rhsv.second, lhsv.second);
 
   if (tsa.print_result ())
     emit_print (lhs->name (), rhsv);
@@ -1449,9 +1610,7 @@
   llvm::Value *lindex = builder.CreateLoad (llvm_index);
   llvm::Value *llvm_iter = builder.CreateCall2 (index_ol.function, over.second, lindex);
   value iter(index_ol.result, llvm_iter);
-
-  jit_function::overload assign = tinfo->assign_op (lhsv.first, iter.first);
-  builder.CreateCall2 (assign.function, lhsv.second, iter.second);
+  builder.CreateStore (iter.second, lhsv.second);
 
   tree_statement_list *lst = cmd.body ();
   lst->accept (*this);
--- a/src/pt-jit.h	Fri May 18 10:24:14 2012 -0600
+++ b/src/pt-jit.h	Mon May 21 15:41:19 2012 -0600
@@ -52,11 +52,12 @@
 // endfor
 // Will compile. Nested for loops with constant bounds are also supported.
 //
+// If statements/comparisons compile, but && and || do not.
+//
 // TODO:
-// 1. Support if statements
-// 2. Support iteration over matricies
-// 3. Check error state
-// 4. ...
+// 1. Support iteration over matricies
+// 2. Check error state
+// 3. ...
 // ---------------------------------------------------------
 
 
@@ -74,6 +75,7 @@
   class LLVMContext;
   class Type;
   class GenericValue;
+  class Twine;
 }
 
 class octave_base_value;
@@ -267,7 +269,15 @@
       return ol.result;
     }
 
-  const jit_function::overload& assign_op (jit_type *lhs, jit_type *rhs) const;
+  const jit_function::overload& grab (jit_type *ty) const
+  {
+    return grab_fn.get_overload (ty);
+  }
+
+  const jit_function::overload& release (jit_type *ty) const
+  {
+    return release_fn.get_overload (ty);
+  }
 
   const jit_function::overload& print_value (jit_type *to_print) const;
 
@@ -287,6 +297,11 @@
     return ol.result;
   }
 
+  const jit_function::overload& get_logically_true (jit_type *conv) const
+  {
+    return logically_true.get_overload (conv);
+  }
+
   void to_generic (jit_type *type, llvm::GenericValue& gv);
 
   void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov);
@@ -302,6 +317,61 @@
 
   void add_binary_op (jit_type *ty, int op, int llvm_op);
 
+  void add_binary_icmp (jit_type *ty, int op, int llvm_op);
+
+  void add_binary_fcmp (jit_type *ty, int op, int llvm_op);
+
+  llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
+                                   llvm::Type *arg0)
+  {
+    std::vector<llvm::Type *> args (1, arg0);
+    return create_function (name, ret, args);
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, jit_type *ret,
+                                   jit_type *arg0)
+  {
+    return create_function (name, ret->to_llvm (), arg0->to_llvm ());
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
+                                   llvm::Type *arg0, llvm::Type *arg1)
+  {
+    std::vector<llvm::Type *> args (2);
+    args[0] = arg0;
+    args[1] = arg1;
+    return create_function (name, ret, args);
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, jit_type *ret,
+                                   jit_type *arg0, jit_type *arg1)
+  {
+    return create_function (name, ret->to_llvm (), arg0->to_llvm (),
+                            arg1->to_llvm ());
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
+                                   llvm::Type *arg0, llvm::Type *arg1,
+                                   llvm::Type *arg2)
+  {
+    std::vector<llvm::Type *> args (3);
+    args[0] = arg0;
+    args[1] = arg1;
+    args[2] = arg2;
+    return create_function (name, ret, args);
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, jit_type *ret,
+                                   jit_type *arg0, jit_type *arg1,
+                                   jit_type *arg2)
+  {
+    return create_function (name, ret->to_llvm (), arg0->to_llvm (),
+                            arg1->to_llvm (), arg2->to_llvm ());
+  }
+
+  llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
+                                   const std::vector<llvm::Type *>& args);
+
   llvm::Module *module;
   llvm::ExecutionEngine *engine;
   int next_id;
@@ -316,11 +386,13 @@
   jit_type *index;
 
   std::vector<jit_function> binary_ops;
-  jit_function assign_fn;
+  jit_function grab_fn;
+  jit_function release_fn;
   jit_function print_fn;
   jit_function simple_for_check;
   jit_function simple_for_incr;
   jit_function simple_for_index;
+  jit_function logically_true;
 
   std::list<double> scalar_out;
   std::list<octave_base_value *> ov_out;
@@ -435,6 +507,8 @@
 
   void handle_identifier (const symbol_table::symbol_record_ref& record);
 
+  void merge (type_map& dest, const type_map& src);
+
   jit_typeinfo *tinfo;
 
   bool is_lvalue;