diff src/pt-jit.cc @ 14978:f649b66ef1af

Add short circult operators to JIT * src/pt-jit.cc (jit_convert::jit_convert): Initialize short_count. (jit_convert::visit_binary_expression): Add support for short circut operators. (jit_convert::visit_if_command_list): Remove duplicate check append. (jit_convert::visit_simple_assignment): Store result. (jit_convert::convert_llvm::visit): New overload. * src/pt-jit.h (jit_const_bool): New specialization of jit_const. (jit_convert::short_count): New variable.
author Max Brister <max@2bass.com>
date Wed, 27 Jun 2012 23:43:06 -0500
parents d3f9801b1f29
children a5f75de0dab1
line wrap: on
line diff
--- a/src/pt-jit.cc	Wed Jun 27 18:50:59 2012 -0500
+++ b/src/pt-jit.cc	Wed Jun 27 23:43:06 2012 -0500
@@ -1836,7 +1836,7 @@
 
 // -------------------- jit_convert --------------------
 jit_convert::jit_convert (llvm::Module *module, tree &tee)
-  : iterator_count (0), breaking (false)
+  : iterator_count (0), short_count (0), breaking (false)
 {
   jit_instruction::reset_ids ();
 
@@ -1942,18 +1942,65 @@
 void
 jit_convert::visit_binary_expression (tree_binary_expression& be)
 {
-  // this is the case for bool_or and bool_and
   if (be.op_type () >= octave_value::num_binary_ops)
-    fail ("Unsupported binary operator");
-
-  tree_expression *lhs = be.lhs ();
-  jit_value *lhsv = visit (lhs);
-
-  tree_expression *rhs = be.rhs ();
-  jit_value *rhsv = visit (rhs);
-
-  const jit_function& fn = jit_typeinfo::binary_op (be.op_type ());
-  result = create_checked (fn, lhsv, rhsv);
+    {
+      tree_boolean_expression *boole;
+      boole = dynamic_cast<tree_boolean_expression *> (&be);
+      assert (boole);
+      bool is_and = boole->op_type () == tree_boolean_expression::bool_and;
+
+      std::stringstream ss;
+      ss << "#short_result" << short_count++;
+
+      std::string short_name = ss.str ();
+      jit_variable *short_result = create<jit_variable> (short_name);
+      vmap[short_name] = short_result;
+
+      jit_block *done = create<jit_block> (block->name ());
+      tree_expression *lhs = be.lhs ();
+      jit_value *lhsv = visit (lhs);
+      lhsv = create_checked (&jit_typeinfo::logically_true, lhsv);
+
+      jit_block *short_early = create<jit_block> ("short_early");
+      append (short_early);
+
+      jit_block *short_cont = create<jit_block> ("short_cont");
+
+      if (is_and)
+        block->append (create<jit_cond_branch> (lhsv, short_cont, short_early));
+      else
+        block->append (create<jit_cond_branch> (lhsv, short_early, short_cont));
+
+      block = short_early;
+
+      jit_value *early_result = create<jit_const_bool> (! is_and);
+      block->append (create<jit_assign> (short_result, early_result));
+      block->append (create<jit_branch> (done));
+
+      append (short_cont);
+      block = short_cont;
+
+      tree_expression *rhs = be.rhs ();
+      jit_value *rhsv = visit (rhs);
+      rhsv = create_checked (&jit_typeinfo::logically_true, rhsv);
+      block->append (create<jit_assign> (short_result, rhsv));
+      block->append (create<jit_branch> (done));
+
+      append (done);
+      block = done;
+      result = short_result;
+    }
+  else
+    {
+      tree_expression *lhs = be.lhs ();
+      jit_value *lhsv = visit (lhs);
+
+      tree_expression *rhs = be.rhs ();
+      jit_value *rhsv = visit (rhs);
+
+      const jit_function& fn = jit_typeinfo::binary_op (be.op_type ());
+      result = create_checked (fn, lhsv, rhsv);
+    }
 }
 
 void
@@ -2196,8 +2243,6 @@
           jit_value *cond = visit (expr);
           jit_call *check = create_checked (&jit_typeinfo::logically_true,
                                             cond);
-          block->append (check);
-
           jit_block *body = create<jit_block> (i == 0 ? "if_body"
                                                : "ifelse_body");
           append (body);
@@ -2329,7 +2374,7 @@
   tree_expression *rhs = tsa.right_hand_side ();
   jit_value *rhsv = visit (rhs);
 
-  do_assign (tsa.left_hand_side (), rhsv);
+  result = do_assign (tsa.left_hand_side (), rhsv);
 }
 
 void
@@ -2980,6 +3025,12 @@
 }
 
 void
+jit_convert::convert_llvm::visit (jit_const_bool& cb)
+{
+  cb.stash_llvm (llvm::ConstantInt::get (cb.type_llvm (), cb.value ()));
+}
+
+void
 jit_convert::convert_llvm::visit (jit_const_scalar& cs)
 {
   cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ()));