# HG changeset patch # User Max Brister # Date 1351896776 21600 # Node ID 0754bdfbc8fecbbf3b6111a8ee04562896b43815 # Parent 52df2e7baabe46d2a365eb629b06501f0148112d Correct multiplication complex multiplication with NaN in JIT * jit-typeinfo.cc (jit_function::call): Remove dead code. (jit_typeinfo::jit_typeinfo): Fix complex multiplication. (jit_typeinfo::do_type_of): Do not treat complex numbers with 0 imag as complex. * pt-jit.cc (jit_convert::visit_constant): Use jit_typeinfo::type_of. diff -r 52df2e7baabe -r 0754bdfbc8fe libinterp/interp-core/jit-typeinfo.cc --- a/libinterp/interp-core/jit-typeinfo.cc Fri Nov 02 14:32:22 2012 -0600 +++ b/libinterp/interp-core/jit-typeinfo.cc Fri Nov 02 16:52:56 2012 -0600 @@ -633,8 +633,6 @@ throw jit_fail_exception ("Call not implemented"); assert (in_args.size () == args.size ()); - llvm::Function *stacksave - = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); llvm::SmallVector llvm_args; llvm_args.reserve (in_args.size () + sret ()); @@ -1322,12 +1320,29 @@ llvm::Value *one = builder.getInt32 (1); llvm::Value *two = builder.getInt32 (2); llvm::Value *three = builder.getInt32 (3); + llvm::Value *fzero = llvm::ConstantFP::get (scalar_t, 0); + + // we are really dealing with a complex number OR a scalar. That is, if the + // complex component is 0, we really have a scalar. This matters in + // 0+0i * NaN + llvm::BasicBlock *complex_mul = fn.new_block ("complex_mul"); + llvm::BasicBlock *real_mul = fn.new_block ("real_mul"); + llvm::BasicBlock *ret_block = fn.new_block ("ret"); + llvm::Value *temp = builder.CreateFCmpUEQ (complex_imag (lhs), fzero); + llvm::Value *temp2 = builder.CreateFCmpUEQ (complex_imag (rhs), fzero); + temp = builder.CreateAnd (temp, temp2); + builder.CreateCondBr (temp, real_mul, complex_mul); + + builder.SetInsertPoint(real_mul); + temp = builder.CreateFMul (complex_real (lhs), complex_real (rhs)); + llvm::Value *real_branch_ret = complex_new (temp, fzero); + builder.CreateBr (ret_block); llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4); llvm::Value *mlhs = llvm::UndefValue::get (vec4); llvm::Value *mrhs = mlhs; - - llvm::Value *temp = complex_real (lhs); + builder.SetInsertPoint (complex_mul); + temp = complex_real (lhs); mlhs = builder.CreateInsertElement (mlhs, temp, zero); mlhs = builder.CreateInsertElement (mlhs, temp, two); temp = complex_imag (lhs); @@ -1349,7 +1364,15 @@ tlhs = builder.CreateExtractElement (mres, two); trhs = builder.CreateExtractElement (mres, three); llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs); - fn.do_return (builder, complex_new (ret_real, ret_imag)); + llvm::Value *complex_branch_ret = complex_new (ret_real, ret_imag); + builder.CreateBr (ret_block); + + builder.SetInsertPoint (ret_block); + llvm::PHINode *merge = llvm::PHINode::Create(complex_t, 2); + builder.Insert (merge); + merge->addIncoming (real_branch_ret, real_mul); + merge->addIncoming (complex_branch_ret, complex_mul); + fn.do_return (builder, merge); } binary_ops[octave_value::op_mul].add_overload (fn); @@ -1381,10 +1404,25 @@ body = fn.new_block (); builder.SetInsertPoint (body); { + llvm::BasicBlock *complex_mul = fn.new_block ("complex_mul"); + llvm::BasicBlock *scalar_mul = fn.new_block ("scalar_mul"); + + llvm::Value *fzero = llvm::ConstantFP::get (scalar_t, 0); llvm::Value *lhs = fn.argument (builder, 0); - llvm::Value *tlhs = complex_new (lhs, lhs); llvm::Value *rhs = fn.argument (builder, 1); - fn.do_return (builder, builder.CreateFMul (tlhs, rhs)); + + llvm::Value *cmp = builder.CreateFCmpUEQ (complex_imag (rhs), fzero); + builder.CreateCondBr (cmp, scalar_mul, complex_mul); + + builder.SetInsertPoint (scalar_mul); + llvm::Value *temp = complex_real (rhs); + temp = builder.CreateFMul (lhs, temp); + fn.do_return (builder, complex_new (temp, fzero), false); + + + builder.SetInsertPoint (complex_mul); + temp = complex_new (lhs, lhs); + fn.do_return (builder, builder.CreateFMul (temp, rhs)); } binary_ops[octave_value::op_mul].add_overload (fn); binary_ops[octave_value::op_el_mul].add_overload (fn); @@ -2273,7 +2311,14 @@ } if (ov.is_complex_scalar ()) - return get_complex (); + { + Complex cv = ov.complex_value (); + + // We don't really represent complex values, instead we represent + // complex_or_scalar. If the imag value is zero, we assume a scalar. + if (cv.imag () == 0) + return get_complex (); + } return get_any (); } diff -r 52df2e7baabe -r 0754bdfbc8fe libinterp/interp-core/pt-jit.cc --- a/libinterp/interp-core/pt-jit.cc Fri Nov 02 14:32:22 2012 -0600 +++ b/libinterp/interp-core/pt-jit.cc Fri Nov 02 16:52:56 2012 -0600 @@ -571,17 +571,19 @@ jit_convert::visit_constant (tree_constant& tc) { octave_value v = tc.rvalue1 (); - if (v.is_real_scalar () && v.is_double_type () && ! v.is_complex_type ()) + jit_type *ty = jit_typeinfo::type_of (v); + + if (ty == jit_typeinfo::get_scalar ()) { double dv = v.double_value (); result = factory.create (dv); } - else if (v.is_range ()) + else if (ty == jit_typeinfo::get_range ()) { Range rv = v.range_value (); result = factory.create (rv); } - else if (v.is_complex_scalar ()) + else if (ty == jit_typeinfo::get_complex ()) { Complex cv = v.complex_value (); result = factory.create (cv); @@ -2254,6 +2256,35 @@ %! assert (abs (result - 1/9) < 1e-5); %!test +%! temp = 1+1i; +%! nan = NaN; +%! while 1 +%! temp = temp - 1i; +%! temp = temp * nan; +%! break; +%! endwhile +%! assert (imag (temp), 0); + +%!test +%! temp = 1+1i; +%! nan = NaN+1i; +%! while 1 +%! nan = nan - 1i; +%! temp = temp - 1i; +%! temp = temp * nan; +%! break; +%! endwhile +%! assert (imag (temp), 0); + +%!test +%! temp = 1+1i; +%! while 1 +%! temp = temp * 5; +%! break; +%! endwhile +%! assert (temp, 5+5i); + +%!test %! nr = 1001; %! mat = zeros (1, nr); %! for i = 1:nr