Mercurial > octave-nkf
diff libinterp/interp-core/jit-typeinfo.cc @ 15583:0754bdfbc8fe
Correct complex multiplication with NaN in JIT
* jit-typeinfo.cc (jit_function::call): Remove dead code.
(jit_typeinfo::jit_typeinfo): Fix complex multiplication.
(jit_typeinfo::do_type_of): Do not treat complex numbers with 0 imag as complex.
* pt-jit.cc (jit_convert::visit_constant): Use jit_typeinfo::type_of.
author | Max Brister <max@2bass.com> |
---|---|
date | Fri, 02 Nov 2012 16:52:56 -0600 |
parents | 8ccb187b24e9 |
children | 44272909d926 |
line wrap: on
line diff
--- a/libinterp/interp-core/jit-typeinfo.cc Fri Nov 02 14:32:22 2012 -0600 +++ b/libinterp/interp-core/jit-typeinfo.cc Fri Nov 02 16:52:56 2012 -0600 @@ -633,8 +633,6 @@ throw jit_fail_exception ("Call not implemented"); assert (in_args.size () == args.size ()); - llvm::Function *stacksave - = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); llvm::SmallVector<llvm::Value *, 10> llvm_args; llvm_args.reserve (in_args.size () + sret ()); @@ -1322,12 +1320,29 @@ llvm::Value *one = builder.getInt32 (1); llvm::Value *two = builder.getInt32 (2); llvm::Value *three = builder.getInt32 (3); + llvm::Value *fzero = llvm::ConstantFP::get (scalar_t, 0); + + // we are really dealing with a complex number OR a scalar. That is, if the + // complex component is 0, we really have a scalar. This matters in + // 0+0i * NaN + llvm::BasicBlock *complex_mul = fn.new_block ("complex_mul"); + llvm::BasicBlock *real_mul = fn.new_block ("real_mul"); + llvm::BasicBlock *ret_block = fn.new_block ("ret"); + llvm::Value *temp = builder.CreateFCmpUEQ (complex_imag (lhs), fzero); + llvm::Value *temp2 = builder.CreateFCmpUEQ (complex_imag (rhs), fzero); + temp = builder.CreateAnd (temp, temp2); + builder.CreateCondBr (temp, real_mul, complex_mul); + + builder.SetInsertPoint(real_mul); + temp = builder.CreateFMul (complex_real (lhs), complex_real (rhs)); + llvm::Value *real_branch_ret = complex_new (temp, fzero); + builder.CreateBr (ret_block); llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4); llvm::Value *mlhs = llvm::UndefValue::get (vec4); llvm::Value *mrhs = mlhs; - - llvm::Value *temp = complex_real (lhs); + builder.SetInsertPoint (complex_mul); + temp = complex_real (lhs); mlhs = builder.CreateInsertElement (mlhs, temp, zero); mlhs = builder.CreateInsertElement (mlhs, temp, two); temp = complex_imag (lhs); @@ -1349,7 +1364,15 @@ tlhs = builder.CreateExtractElement (mres, two); trhs = builder.CreateExtractElement (mres, three); llvm::Value *ret_imag = 
builder.CreateFAdd (tlhs, trhs); - fn.do_return (builder, complex_new (ret_real, ret_imag)); + llvm::Value *complex_branch_ret = complex_new (ret_real, ret_imag); + builder.CreateBr (ret_block); + + builder.SetInsertPoint (ret_block); + llvm::PHINode *merge = llvm::PHINode::Create(complex_t, 2); + builder.Insert (merge); + merge->addIncoming (real_branch_ret, real_mul); + merge->addIncoming (complex_branch_ret, complex_mul); + fn.do_return (builder, merge); } binary_ops[octave_value::op_mul].add_overload (fn); @@ -1381,10 +1404,25 @@ body = fn.new_block (); builder.SetInsertPoint (body); { + llvm::BasicBlock *complex_mul = fn.new_block ("complex_mul"); + llvm::BasicBlock *scalar_mul = fn.new_block ("scalar_mul"); + + llvm::Value *fzero = llvm::ConstantFP::get (scalar_t, 0); llvm::Value *lhs = fn.argument (builder, 0); - llvm::Value *tlhs = complex_new (lhs, lhs); llvm::Value *rhs = fn.argument (builder, 1); - fn.do_return (builder, builder.CreateFMul (tlhs, rhs)); + + llvm::Value *cmp = builder.CreateFCmpUEQ (complex_imag (rhs), fzero); + builder.CreateCondBr (cmp, scalar_mul, complex_mul); + + builder.SetInsertPoint (scalar_mul); + llvm::Value *temp = complex_real (rhs); + temp = builder.CreateFMul (lhs, temp); + fn.do_return (builder, complex_new (temp, fzero), false); + + + builder.SetInsertPoint (complex_mul); + temp = complex_new (lhs, lhs); + fn.do_return (builder, builder.CreateFMul (temp, rhs)); } binary_ops[octave_value::op_mul].add_overload (fn); binary_ops[octave_value::op_el_mul].add_overload (fn); @@ -2273,7 +2311,14 @@ } if (ov.is_complex_scalar ()) - return get_complex (); + { + Complex cv = ov.complex_value (); + + // We don't really represent complex values, instead we represent + // complex_or_scalar. If the imag value is zero, we assume a scalar. + if (cv.imag () == 0) + return get_complex (); + } return get_any (); }