changeset 15592:6fc163b59746

Correctly handle break/continue for some corner cases in JIT * pt-jit.cc (jit_break_exception): New class. (jit_convert::jit_convert, jit_convert::visit_simple_for_command, jit_convert::visit_if_command_list, jit_convert::visit_while_command): Handle breaks/continues correctly. (jit_convert::visit_break_command, jit_convert::visit_continue_command): Throw instead of setting breaking. (jit_convert::visit_statement_list): Do not check breaking. (jit_convert::initialize): Do not initialize breaking. * pt-jit.h (jit_convert::breaking): Remove variable.
author Max Brister <max@2bass.com>
date Sun, 04 Nov 2012 15:38:48 -0700
parents 8be22193532b
children 24bbd2efea12
files libinterp/interp-core/pt-jit.cc libinterp/interp-core/pt-jit.h
diffstat 2 files changed, 119 insertions(+), 45 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/interp-core/pt-jit.cc	Sun Nov 04 11:23:38 2012 -0600
+++ b/libinterp/interp-core/pt-jit.cc	Sun Nov 04 15:38:48 2012 -0700
@@ -61,6 +61,12 @@
 
 static llvm::LLVMContext& context = llvm::getGlobalContext ();
 
+// -------------------- jit_break_exception --------------------
+
+// jit_break is thrown whenever a branch we are converting has only breaks or
+// continues. This is because all code that follows a break or continue is dead.
+class jit_break_exception : public std::exception {};
+
 // -------------------- jit_convert --------------------
 jit_convert::jit_convert (tree &tee, jit_type *for_bounds)
   : converting_function (false)
@@ -70,10 +76,14 @@
   if (for_bounds)
     create_variable (next_for_bounds (false), for_bounds);
 
-  visit (tee);
+  try
+    {
+      visit (tee);
+    }
+  catch (const jit_break_exception&)
+    {}
 
   // breaks must have been handled by the top level loop
-  assert (! breaking);
   assert (breaks.empty ());
   assert (continues.empty ());
 
@@ -120,32 +130,48 @@
     }
 
   jit_value *return_value = 0;
+  bool all_breaking = false;
   if (fcn.is_special_expr ())
     {
       tree_expression *expr = fcn.special_expr ();
       if (expr)
         {
           jit_variable *retvar = get_variable ("#return");
-          jit_value *retval = visit (expr);
+          jit_value *retval;
+          try
+            {
+              retval = visit (expr);
+            }
+          catch (const jit_break_exception&)
+            {}
+
+          if (breaks.size () || continues.size ())
+            throw jit_fail_exception ("break/continue not supported in "
+                                      "anonymous functions");
+
           block->append (factory.create<jit_assign> (retvar, retval));
           return_value = retvar;
         }
     }
   else
-    visit_statement_list (*fcn.body ());
-
-  // the user may use break or continue to exit the function. Because the
-  // function does not start as a loop, we can have one continue, one break, or
-  // a regular fallthrough to exit the function
-  if (continues.size ())
     {
-      assert (! continues.size ());
+      try
+        {
+          visit_statement_list (*fcn.body ());
+        }
+      catch (const jit_break_exception&)
+        {
+          all_breaking = true;
+        }
+
+      // the user may use break or continue to exit the function
       finish_breaks (final_block, continues);
+      finish_breaks (final_block, breaks);
     }
-  else if (breaks.size ())
-    finish_breaks (final_block, breaks);
-  else
+
+  if (! all_breaking)
     block->append (factory.create<jit_branch> (final_block));
+
   blocks.push_back (final_block);
   block = final_block;
 
@@ -251,7 +277,7 @@
 jit_convert::visit_break_command (tree_break_command&)
 {
   breaks.push_back (block);
-  breaking = true;
+  throw jit_break_exception ();
 }
 
 void
@@ -276,7 +302,7 @@
 jit_convert::visit_continue_command (tree_continue_command&)
 {
   continues.push_back (block);
-  breaking = true;
+  throw jit_break_exception ();
 }
 
 void
@@ -311,11 +337,9 @@
   // and used only inside the for loop (e.g. the index variable)
 
   // If we are a nested for loop we need to store the previous breaks
-  assert (! breaking);
   unwind_protect prot;
   prot.protect_var (breaks);
   prot.protect_var (continues);
-  prot.protect_var (breaking);
   breaks.clear ();
   continues.clear ();
 
@@ -354,23 +378,31 @@
 
   // do loop
   tree_statement_list *pt_body = cmd.body ();
-  pt_body->accept (*this);
-
-  if (breaking && continues.empty ())
+  bool all_breaking = false;
+  try
+    {
+      pt_body->accept (*this);
+    }
+  catch (const jit_break_exception&)
     {
-      // WTF are you doing user? Every branch was a continue, why did you have
-      // a loop??? Users are silly people...
-      finish_breaks (tail, breaks);
-      blocks.push_back (tail);
-      block = tail;
-      return;
+      if (continues.empty ())
+        {
+          // WTF are you doing user? Every branch was a break, why did you have
+          // a loop??? Users are silly people...
+          finish_breaks (tail, breaks);
+          blocks.push_back (tail);
+          block = tail;
+          return;
+        }
+
+      all_breaking = true;
     }
 
   // check our condition, continues jump to this block
   jit_block *check_block = factory.create<jit_block> ("for_check");
   blocks.push_back (check_block);
 
-  if (! breaking)
+  if (! all_breaking)
     block->append (factory.create<jit_branch> (check_block));
   finish_breaks (check_block, continues);
 
@@ -380,8 +412,8 @@
   jit_call *iter_inc = factory.create<jit_call> (add_fn, iterator, one);
   block->append (iter_inc);
   block->append (factory.create<jit_assign> (iterator, iter_inc));
-  check = block->append (factory.create<jit_call> (jit_typeinfo::for_check, control,
-                                           iterator));
+  check = block->append (factory.create<jit_call> (jit_typeinfo::for_check,
+                                                   control, iterator));
   block->append (factory.create<jit_cond_branch> (check, body, tail));
 
   // breaks will go to our tail
@@ -487,6 +519,13 @@
   if (! last_else)
     entry_blocks[entry_blocks.size () - 1] = tail;
 
+
+  // each branch in the if statement will have different breaks/continues
+  block_list current_breaks = breaks;
+  block_list current_continues = continues;
+  breaks.clear ();
+  continues.clear ();
+
   size_t num_incomming = 0; // number of incomming blocks to our tail
   iter = lst.begin ();
   for (size_t i = 0; iter != lst.end (); ++iter, ++i)
@@ -516,17 +555,23 @@
 
       tree_statement_list *stmt_lst = tic->commands ();
       assert (stmt_lst); // jwe: Can this be null?
-      stmt_lst->accept (*this);
-
-      if (breaking)
-        breaking = false;
-      else
+
+      try
         {
+          stmt_lst->accept (*this);
           ++num_incomming;
           block->append (factory.create<jit_branch> (tail));
         }
+      catch(const jit_break_exception&)
+        {}
+
+      current_breaks.splice (current_breaks.end (), breaks);
+      current_continues.splice (current_continues.end (), continues);
     }
 
+  breaks.splice (breaks.end (), current_breaks);
+  continues.splice (continues.end (), current_continues);
+
   if (num_incomming || ! last_else)
     {
       blocks.push_back (tail);
@@ -534,7 +579,7 @@
     }
   else
     // every branch broke, so we don't have a tail
-    breaking = true;
+    throw jit_break_exception ();
 }
 
 void
@@ -714,9 +759,6 @@
       // jwe: Can this ever be null?
       assert (elt);
       elt->accept (*this);
-
-      if (breaking)
-        break;
     }
 }
 
@@ -753,11 +795,9 @@
 void
 jit_convert::visit_while_command (tree_while_command& wc)
 {
-  assert (! breaking);
   unwind_protect prot;
   prot.protect_var (breaks);
   prot.protect_var (continues);
-  prot.protect_var (breaking);
   breaks.clear ();
   continues.clear ();
 
@@ -779,13 +819,23 @@
   block = body;
 
   tree_statement_list *loop_body = wc.body ();
+  bool all_breaking = false;
   if (loop_body)
-    loop_body->accept (*this);
+    {
+      try
+        {
+          loop_body->accept (*this);
+        }
+      catch (const jit_break_exception&)
+        {
+          all_breaking = true;
+        }
+    }
 
   finish_breaks (tail, breaks);
   finish_breaks (cond_check, continues);
 
-  if (! breaking)
+  if (! all_breaking)
     block->append (factory.create<jit_branch> (cond_check));
 
   blocks.push_back (tail);
@@ -805,7 +855,6 @@
   iterator_count = 0;
   for_bounds_count = 0;
   short_count = 0;
-  breaking = false;
   jit_instruction::reset_ids ();
 
   entry_block = factory.create<jit_block> ("body");
@@ -2239,6 +2288,33 @@
 Test some simple cases that compile.
 
 %!test
+%! for i=1:1e6
+%!   if i < 5
+%!     break
+%!   else
+%!     break
+%!   endif
+%! endfor
+%! assert (i, 1);
+
+%!test
+%! while 1
+%!   if 1
+%!     break
+%!  else
+%!    break
+%!  endif
+%! endwhile
+
+%!test
+%! for i=1:1e6
+%!   if i == 100
+%!     break
+%!   endif
+%! endfor
+%! assert (i, 100);
+
+%!test
 %! inc = 1e-5;
 %! result = 0;
 %! for ii = 0:inc:1
--- a/libinterp/interp-core/pt-jit.h	Sun Nov 04 11:23:38 2012 -0600
+++ b/libinterp/interp-core/pt-jit.h	Sun Nov 04 15:38:48 2012 -0700
@@ -231,8 +231,6 @@
 
   jit_value *visit (tree& tee);
 
-  bool breaking; // true if we are breaking OR continuing
-
   typedef std::list<jit_block *> block_list;
   block_list breaks;
   block_list continues;