changeset 15583:0754bdfbc8fe

Correct multiplication complex multiplication with NaN in JIT * jit-typeinfo.cc (jit_function::call): Remove dead code. (jit_typeinfo::jit_typeinfo): Fix complex multiplication. (jit_typeinfo::do_type_of): Do not treat complex numbers with 0 imag as complex. * pt-jit.cc (jit_convert::visit_constant): Use jit_typeinfo::type_of.
author Max Brister <max@2bass.com>
date Fri, 02 Nov 2012 16:52:56 -0600
parents 52df2e7baabe
children ae0af6c664c4
files libinterp/interp-core/jit-typeinfo.cc libinterp/interp-core/pt-jit.cc
diffstat 2 files changed, 87 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/interp-core/jit-typeinfo.cc	Fri Nov 02 14:32:22 2012 -0600
+++ b/libinterp/interp-core/jit-typeinfo.cc	Fri Nov 02 16:52:56 2012 -0600
@@ -633,8 +633,6 @@
     throw jit_fail_exception ("Call not implemented");
 
   assert (in_args.size () == args.size ());
-  llvm::Function *stacksave
-    = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave);
   llvm::SmallVector<llvm::Value *, 10> llvm_args;
   llvm_args.reserve (in_args.size () + sret ());
 
@@ -1322,12 +1320,29 @@
     llvm::Value *one = builder.getInt32 (1);
     llvm::Value *two = builder.getInt32 (2);
     llvm::Value *three = builder.getInt32 (3);
+    llvm::Value *fzero = llvm::ConstantFP::get (scalar_t, 0);
+
+    // we are really dealing with a complex number OR a scalar. That is, if the
+    // complex component is 0, we really have a scalar. This matters in
+    // 0+0i * NaN
+    llvm::BasicBlock *complex_mul = fn.new_block ("complex_mul");
+    llvm::BasicBlock *real_mul = fn.new_block ("real_mul");
+    llvm::BasicBlock *ret_block = fn.new_block ("ret");
+    llvm::Value *temp = builder.CreateFCmpUEQ (complex_imag (lhs), fzero);
+    llvm::Value *temp2 = builder.CreateFCmpUEQ (complex_imag (rhs), fzero);
+    temp = builder.CreateAnd (temp, temp2);
+    builder.CreateCondBr (temp, real_mul, complex_mul);
+
+    builder.SetInsertPoint(real_mul);
+    temp = builder.CreateFMul (complex_real (lhs), complex_real (rhs));
+    llvm::Value *real_branch_ret = complex_new (temp, fzero);
+    builder.CreateBr (ret_block);
 
     llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4);
     llvm::Value *mlhs = llvm::UndefValue::get (vec4);
     llvm::Value *mrhs = mlhs;
-
-    llvm::Value *temp = complex_real (lhs);
+    builder.SetInsertPoint (complex_mul);
+    temp = complex_real (lhs);
     mlhs = builder.CreateInsertElement (mlhs, temp, zero);
     mlhs = builder.CreateInsertElement (mlhs, temp, two);
     temp = complex_imag (lhs);
@@ -1349,7 +1364,15 @@
     tlhs = builder.CreateExtractElement (mres, two);
     trhs = builder.CreateExtractElement (mres, three);
     llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs);
-    fn.do_return (builder, complex_new (ret_real, ret_imag));
+    llvm::Value *complex_branch_ret = complex_new (ret_real, ret_imag);
+    builder.CreateBr (ret_block);
+
+    builder.SetInsertPoint (ret_block);
+    llvm::PHINode *merge = llvm::PHINode::Create(complex_t, 2);
+    builder.Insert (merge);
+    merge->addIncoming (real_branch_ret, real_mul);
+    merge->addIncoming (complex_branch_ret, complex_mul);
+    fn.do_return (builder, merge);
   }
 
   binary_ops[octave_value::op_mul].add_overload (fn);
@@ -1381,10 +1404,25 @@
   body = fn.new_block ();
   builder.SetInsertPoint (body);
   {
+    llvm::BasicBlock *complex_mul = fn.new_block ("complex_mul");
+    llvm::BasicBlock *scalar_mul = fn.new_block ("scalar_mul");
+
+    llvm::Value *fzero = llvm::ConstantFP::get (scalar_t, 0);
     llvm::Value *lhs = fn.argument (builder, 0);
-    llvm::Value *tlhs = complex_new (lhs, lhs);
     llvm::Value *rhs = fn.argument (builder, 1);
-    fn.do_return (builder, builder.CreateFMul (tlhs, rhs));
+
+    llvm::Value *cmp = builder.CreateFCmpUEQ (complex_imag (rhs), fzero);
+    builder.CreateCondBr (cmp, scalar_mul, complex_mul);
+
+    builder.SetInsertPoint (scalar_mul);
+    llvm::Value *temp = complex_real (rhs);
+    temp = builder.CreateFMul (lhs, temp);
+    fn.do_return (builder, complex_new (temp, fzero), false);
+
+
+    builder.SetInsertPoint (complex_mul);
+    temp = complex_new (lhs, lhs);
+    fn.do_return (builder, builder.CreateFMul (temp, rhs));
   }
   binary_ops[octave_value::op_mul].add_overload (fn);
   binary_ops[octave_value::op_el_mul].add_overload (fn);
@@ -2273,7 +2311,14 @@
     }
 
   if (ov.is_complex_scalar ())
-    return get_complex ();
+    {
+      Complex cv = ov.complex_value ();
+
+      // We don't really represent complex values, instead we represent
+      // complex_or_scalar. If the imag value is zero, we assume a scalar.
+      if (cv.imag () == 0)
+        return get_complex ();
+    }
 
   return get_any ();
 }
--- a/libinterp/interp-core/pt-jit.cc	Fri Nov 02 14:32:22 2012 -0600
+++ b/libinterp/interp-core/pt-jit.cc	Fri Nov 02 16:52:56 2012 -0600
@@ -571,17 +571,19 @@
 jit_convert::visit_constant (tree_constant& tc)
 {
   octave_value v = tc.rvalue1 ();
-  if (v.is_real_scalar () && v.is_double_type () && ! v.is_complex_type ())
+  jit_type *ty = jit_typeinfo::type_of (v);
+
+  if (ty == jit_typeinfo::get_scalar ())
     {
       double dv = v.double_value ();
       result = factory.create<jit_const_scalar> (dv);
     }
-  else if (v.is_range ())
+  else if (ty == jit_typeinfo::get_range ())
     {
       Range rv = v.range_value ();
       result = factory.create<jit_const_range> (rv);
     }
-  else if (v.is_complex_scalar ())
+  else if (ty == jit_typeinfo::get_complex ())
     {
       Complex cv = v.complex_value ();
       result = factory.create<jit_const_complex> (cv);
@@ -2254,6 +2256,35 @@
 %! assert (abs (result - 1/9) < 1e-5);
 
 %!test
+%! temp = 1+1i;
+%! nan = NaN;
+%! while 1
+%!   temp = temp - 1i;
+%!   temp = temp * nan;
+%!   break;
+%! endwhile
+%! assert (imag (temp), 0);
+
+%!test
+%! temp = 1+1i;
+%! nan = NaN+1i;
+%! while 1
+%!   nan = nan - 1i;
+%!   temp = temp - 1i;
+%!   temp = temp * nan;
+%!   break;
+%! endwhile
+%! assert (imag (temp), 0);
+
+%!test
+%! temp = 1+1i;
+%! while 1
+%!   temp = temp * 5;
+%!   break;
+%! endwhile
+%! assert (temp, 5+5i);
+
+%!test
 %! nr = 1001;
 %! mat = zeros (1, nr);
 %! for i = 1:nr