# HG changeset patch # User Max Brister # Date 1342035776 18000 # Node ID 822d52bee9731b1162b6dd806837f9c9fbf1ac86 # Parent 65f74f52886cc8c30f2a946646f8435c11a12133 More support for complex-complex and complex-scalar operations in JIT * src/pt-jit.cc (xisint, octave_jit_pow_scalar_scalar, octave_jit_pow_complex_complex, octave_jit_pow_complex_scalar, octave_jit_pow_scalar_complex, jit_typeinfo::mirror_binary, jit_typeinfo::complex_real, jit_typeinfo::complex_imag, jit_typeinfo::complex_new): New function. (jit_typeinfo::jit_typeinfo): Support more complex functionality. * src/pt-jit.h (jit_typeinfo::mirror_binary, jit_typeinfo::complex_real, jit_typeinfo::complex_imag, jit_typeinfo::complex_new): New declaration. diff -r 65f74f52886c -r 822d52bee973 src/pt-jit.cc --- a/src/pt-jit.cc Tue Jul 10 21:25:51 2012 -0500 +++ b/src/pt-jit.cc Wed Jul 11 14:42:56 2012 -0500 @@ -333,6 +333,48 @@ return lhs / rhs; } +// FIXME: CP form src/xpow.cc +static inline int +xisint (double x) +{ + return (D_NINT (x) == x + && ((x >= 0 && x < INT_MAX) + || (x <= 0 && x > INT_MIN))); +} + +extern "C" Complex +octave_jit_pow_scalar_scalar (double lhs, double rhs) +{ + // FIXME: almost CP from src/xpow.cc + if (lhs < 0.0 && ! xisint (rhs)) + return std::pow (Complex (lhs), rhs); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_complex_complex (Complex lhs, Complex rhs) +{ + if (lhs.imag () == 0 && rhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs.real (), rhs.real ()); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_complex_scalar (Complex lhs, double rhs) +{ + if (lhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs.real (), rhs); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_scalar_complex (double lhs, Complex rhs) +{ + if (rhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs, rhs.real ()); + return std::pow (lhs, rhs); +} + extern "C" void octave_jit_print_matrix (jit_matrix *m) { @@ -544,12 +586,12 @@ // create types any = new_type ("any", 0, any_t); matrix = new_type ("matrix", any, matrix_t); - scalar = new_type ("scalar", any, scalar_t); + complex = new_type ("complex", any, complex_t); + scalar = new_type ("scalar", complex, scalar_t); range = new_type ("range", any, range_t); string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); - complex = new_type ("complex", any, complex_t); casts.resize (next_id + 1); identities.resize (next_id + 1, 0); @@ -692,21 +734,25 @@ llvm::verifyFunction (*fn); // ldiv is the same as div with the operators reversed - llvm::Function *div = fn; - fn = create_function ("octave_jit_ldiv_scalar_scalar", scalar, scalar, - scalar); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); + fn = mirror_binary (fn); { - llvm::Value *ret = builder.CreateCall2 (div, ++fn->arg_begin (), - fn->arg_begin ()); - builder.CreateRet (ret); - jit_operation::overload ol (fn, true, scalar, scalar, scalar); binary_ops[octave_value::op_ldiv].add_overload (ol); binary_ops[octave_value::op_el_ldiv].add_overload (ol); } - llvm::verifyFunction (*fn); + + // In general, the result of scalar ^ scalar is a complex number. We might be + // able to improve on this if we keep track of the range of values varaibles + // can take on. + fn = create_function ("octave_jit_pow_scalar_scalar", complex_ret, scalar_t, + scalar_t); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_scalar_scalar)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, scalar, + scalar); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } // now for binary complex operations add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd); @@ -734,32 +780,29 @@ llvm::Value *mlhs = llvm::UndefValue::get (vec4); llvm::Value *mrhs = mlhs; - llvm::Value *temp = builder.CreateExtractElement (lhs, zero); + llvm::Value *temp = complex_real (lhs); mlhs = builder.CreateInsertElement (mlhs, temp, zero); mlhs = builder.CreateInsertElement (mlhs, temp, two); - temp = builder.CreateExtractElement (lhs, one); + temp = complex_imag (lhs); mlhs = builder.CreateInsertElement (mlhs, temp, one); mlhs = builder.CreateInsertElement (mlhs, temp, three); - temp = builder.CreateExtractElement (rhs, zero); + temp = complex_real (rhs); mrhs = builder.CreateInsertElement (mrhs, temp, zero); mrhs = builder.CreateInsertElement (mrhs, temp, three); - temp = builder.CreateExtractElement (rhs, one); + temp = complex_imag (rhs); mrhs = builder.CreateInsertElement (mrhs, temp, one); mrhs = builder.CreateInsertElement (mrhs, temp, two); llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); - llvm::Value *ret = llvm::UndefValue::get (complex_t); llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); llvm::Value *trhs = builder.CreateExtractElement (mres, one); - temp = builder.CreateFSub (tlhs, trhs); - ret = builder.CreateInsertElement (ret, temp, zero); + llvm::Value *ret_real = builder.CreateFSub (tlhs, trhs); tlhs = builder.CreateExtractElement (mres, two); trhs = builder.CreateExtractElement (mres, three); - temp = builder.CreateFAdd (tlhs, trhs); - ret = builder.CreateInsertElement (ret, temp, one); - builder.CreateRet (ret); + llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs); + builder.CreateRet (complex_new (ret_real, ret_imag)); jit_operation::overload ol (fn, false, complex, complex, complex); binary_ops[octave_value::op_mul].add_overload (ol); @@ -767,42 +810,6 @@ } llvm::verifyFunction (*fn); - fn = create_function ("octave_jit_*_scalar_complex", complex, scalar, - complex); - llvm::Function *mul_scalar_complex = fn; - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *lhs = fn->arg_begin (); - llvm::Value *tlhs = llvm::UndefValue::get (complex_t); - tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (0)); - tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (1)); - - llvm::Value *rhs = ++fn->arg_begin (); - builder.CreateRet (builder.CreateFMul (tlhs, rhs)); - - jit_operation::overload ol (fn, false, complex, scalar, complex); - binary_ops[octave_value::op_mul].add_overload (ol); - binary_ops[octave_value::op_el_mul].add_overload (ol); - } - llvm::verifyFunction (*fn); - - fn = create_function ("octave_jit_*_complex_scalar", complex, complex, - scalar); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *ret = builder.CreateCall2 (mul_scalar_complex, - ++fn->arg_begin (), - fn->arg_begin ()); - builder.CreateRet (ret); - - jit_operation::overload ol (fn, false, complex, complex, scalar); - binary_ops[octave_value::op_mul].add_overload (ol); - binary_ops[octave_value::op_el_mul].add_overload (ol); - } - llvm::verifyFunction (*fn); - llvm::Function *complex_div = create_function ("octave_jit_complex_div", complex_ret, complex_ret, complex_ret); @@ -815,18 +822,114 @@ binary_ops[octave_value::op_ldiv].add_overload (ol); } - fn = create_function ("octave_jit_\\_complex_complex", complex, complex, + fn = mirror_binary (complex_div); + { + jit_operation::overload ol (fn, true, complex, complex, complex); + binary_ops[octave_value::op_ldiv].add_overload (ol); + binary_ops[octave_value::op_el_ldiv].add_overload (ol); + } + + fn = create_function ("octave_jit_pow_complex_complex", complex_ret, + complex_ret, complex_ret); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_complex_complex)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, complex, + complex); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } + + fn = create_function ("octave_jit_*_scalar_complex", complex, scalar, + complex); + llvm::Function *mul_scalar_complex = fn; + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *tlhs = complex_new (lhs, lhs); + llvm::Value *rhs = ++fn->arg_begin (); + builder.CreateRet (builder.CreateFMul (tlhs, rhs)); + + jit_operation::overload ol (fn, false, complex, scalar, complex); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + llvm::verifyFunction (*fn); + + fn = mirror_binary (mul_scalar_complex); + { + jit_operation::overload ol (fn, false, complex, complex, scalar); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + + fn = create_function ("octave_jit_+_scalar_complex", complex, scalar, complex); body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { - builder.CreateRet (builder.CreateCall2 (complex_div, ++fn->arg_begin (), - fn->arg_begin ())); - jit_operation::overload ol (fn, true, complex, complex, complex); - binary_ops[octave_value::op_ldiv].add_overload (ol); - binary_ops[octave_value::op_el_ldiv].add_overload (ol); + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *real = builder.CreateFAdd (lhs, complex_real (rhs)); + builder.CreateRet (complex_real (rhs, real)); + llvm::verifyFunction (*fn); + + binary_ops[octave_value::op_add].add_overload (fn, false, complex, scalar, + complex); + fn = mirror_binary (fn); + binary_ops[octave_value::op_add].add_overload (fn, false, complex, complex, + scalar); + } + + fn = create_function ("octave_jit_-_complex_scalar", complex, complex, + scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *real = builder.CreateFSub (complex_real (lhs), rhs); + builder.CreateRet (complex_real (lhs, real)); + llvm::verifyFunction (*fn); + + binary_ops[octave_value::op_sub].add_overload (fn, false, complex, complex, + scalar); } - llvm::verifyFunction (*fn); + + fn = create_function ("octave_jit_-_scalar_complex", complex, scalar, + complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *real = builder.CreateFSub (lhs, complex_real (rhs)); + builder.CreateRet (complex_real (rhs, real)); + llvm::verifyFunction (*fn); + + binary_ops[octave_value::op_sub].add_overload (fn, false, complex, scalar, + complex); + } + + fn = create_function ("octave_jit_pow_scalar_complex", complex_ret, + scalar_t, complex_ret); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_scalar_complex)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, scalar, + complex); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } + + fn = create_function ("octave_jit_pow_complex_scalar", complex_ret, + complex_ret, scalar_t); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_complex_complex)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, complex, + scalar); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } // now for binary index operators add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); @@ -1177,6 +1280,27 @@ casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex, any); + // cast complex <- scalar + fn = create_function ("octave_jit_cast_complex_scalar", complex, scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + builder.CreateRet (complex_new (fn->arg_begin (), zero)); + llvm::verifyFunction (*fn); + } + casts[complex->type_id ()].add_overload (fn, false, complex, scalar); + + // cast scalar <- complex + fn = create_function ("octave_jit_cast_scalar_complex", scalar, complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + builder.CreateRet (complex_real (fn->arg_begin ())); + llvm::verifyFunction (*fn); + } + casts[scalar->type_id ()].add_overload (fn, false, scalar, complex); + // cast any <- any fn = create_identity (any); casts[any->type_id ()].add_overload (fn, false, any, any); @@ -1420,6 +1544,27 @@ } llvm::Function * +jit_typeinfo::mirror_binary (llvm::Function *fn) +{ + llvm::FunctionType *fn_type = fn->getFunctionType (); + llvm::Function *ret = create_function (fn->getName () + "_reverse", + fn_type->getReturnType (), + fn_type->getParamType (1), + fn_type->getParamType (0)); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", ret); + builder.SetInsertPoint (body); + llvm::Value *result = builder.CreateCall2 (fn, ++ret->arg_begin (), + ret->arg_begin ()); + if (ret->getReturnType () == builder.getVoidTy ()) + builder.CreateRetVoid (); + else + builder.CreateRet (result); + + llvm::verifyFunction (*ret); + return ret; +} + +llvm::Function * jit_typeinfo::wrap_complex (llvm::Function *wrap) { llvm::SmallVector new_args; @@ -1490,6 +1635,38 @@ return builder.CreateInsertElement (ret, imag, builder.getInt32 (1)); } +llvm::Value * +jit_typeinfo::complex_real (llvm::Value *cx) +{ + return builder.CreateExtractElement (cx, builder.getInt32 (0)); +} + +llvm::Value * +jit_typeinfo::complex_real (llvm::Value *cx, llvm::Value *real) +{ + return builder.CreateInsertElement (cx, real, builder.getInt32 (0)); +} + +llvm::Value * +jit_typeinfo::complex_imag (llvm::Value *cx) +{ + return builder.CreateExtractElement (cx, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_imag (llvm::Value *cx, llvm::Value *imag) +{ + return builder.CreateInsertElement (cx, imag, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_new (llvm::Value *real, llvm::Value *imag) +{ + llvm::Value *ret = llvm::UndefValue::get (complex->to_llvm ()); + ret = complex_real (ret, real); + return complex_imag (ret, imag); +} + jit_type * jit_typeinfo::do_type_of (const octave_value &ov) const { diff -r 65f74f52886c -r 822d52bee973 src/pt-jit.h --- a/src/pt-jit.h Tue Jul 10 21:25:51 2012 -0500 +++ b/src/pt-jit.h Wed Jul 11 14:42:56 2012 -0500 @@ -668,12 +668,24 @@ octave_builtin *find_builtin (const std::string& name); + llvm::Function *mirror_binary (llvm::Function *fn); + llvm::Function *wrap_complex (llvm::Function *wrap); llvm::Value *pack_complex (llvm::Value *cplx); llvm::Value *unpack_complex (llvm::Value *result); + llvm::Value *complex_real (llvm::Value *cx); + + llvm::Value *complex_real (llvm::Value *cx, llvm::Value *real); + + llvm::Value *complex_imag (llvm::Value *cx); + + llvm::Value *complex_imag (llvm::Value *cx, llvm::Value *imag); + + llvm::Value *complex_new (llvm::Value *real, llvm::Value *imag); + static jit_typeinfo *instance; llvm::Module *module;