changeset 14988:822d52bee973

More support for complex-complex and complex-scalar operations in JIT * src/pt-jit.cc (xisint, octave_jit_pow_scalar_scalar, octave_jit_pow_complex_complex, octave_jit_pow_complex_scalar, octave_jit_pow_scalar_complex, jit_typeinfo::mirror_binary, jit_typeinfo::complex_real, jit_typeinfo::complex_imag, jit_typeinfo::complex_new): New function. (jit_typeinfo::jit_typeinfo): Support more complex functionality. * src/pt-jit.h (jit_typeinfo::mirror_binary, jit_typeinfo::complex_real, jit_typeinfo::complex_imag, jit_typeinfo::complex_new): New declaration.
author Max Brister <max@2bass.com>
date Wed, 11 Jul 2012 14:42:56 -0500
parents 65f74f52886c
children dfaa8da46364
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 254 insertions(+), 65 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc	Tue Jul 10 21:25:51 2012 -0500
+++ b/src/pt-jit.cc	Wed Jul 11 14:42:56 2012 -0500
@@ -333,6 +333,48 @@
   return lhs / rhs;
 }
 
+// FIXME: CP form src/xpow.cc
+static inline int
+xisint (double x)
+{
+  return (D_NINT (x) == x
+          && ((x >= 0 && x < INT_MAX)
+              || (x <= 0 && x > INT_MIN)));
+}
+
+extern "C" Complex
+octave_jit_pow_scalar_scalar (double lhs, double rhs)
+{
+  // FIXME: almost CP from src/xpow.cc
+  if (lhs < 0.0 && ! xisint (rhs))
+    return std::pow (Complex (lhs), rhs);
+  return std::pow (lhs, rhs);
+}
+
+extern "C" Complex
+octave_jit_pow_complex_complex (Complex lhs, Complex rhs)
+{
+  if (lhs.imag () == 0 && rhs.imag () == 0)
+    return octave_jit_pow_scalar_scalar (lhs.real (), rhs.real ());
+  return std::pow (lhs, rhs);
+}
+
+extern "C" Complex
+octave_jit_pow_complex_scalar (Complex lhs, double rhs)
+{
+  if (lhs.imag () == 0)
+    return octave_jit_pow_scalar_scalar (lhs.real (), rhs);
+  return std::pow (lhs, rhs);
+}
+
+extern "C" Complex
+octave_jit_pow_scalar_complex (double lhs, Complex rhs)
+{
+  if (rhs.imag () == 0)
+    return octave_jit_pow_scalar_scalar (lhs, rhs.real ());
+  return std::pow (lhs, rhs);
+}
+
 extern "C" void
 octave_jit_print_matrix (jit_matrix *m)
 {
@@ -544,12 +586,12 @@
   // create types
   any = new_type ("any", 0, any_t);
   matrix = new_type ("matrix", any, matrix_t);
-  scalar = new_type ("scalar", any, scalar_t);
+  complex = new_type ("complex", any, complex_t);
+  scalar = new_type ("scalar", complex, scalar_t);
   range = new_type ("range", any, range_t);
   string = new_type ("string", any, string_t);
   boolean = new_type ("bool", any, bool_t);
   index = new_type ("index", any, index_t);
-  complex = new_type ("complex", any, complex_t);
 
   casts.resize (next_id + 1);
   identities.resize (next_id + 1, 0);
@@ -692,21 +734,25 @@
   llvm::verifyFunction (*fn);
 
   // ldiv is the same as div with the operators reversed
-  llvm::Function *div = fn;
-  fn = create_function ("octave_jit_ldiv_scalar_scalar", scalar, scalar,
-                        scalar);
-  body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
+  fn = mirror_binary (fn);
   {
-    llvm::Value *ret = builder.CreateCall2 (div, ++fn->arg_begin (),
-                                            fn->arg_begin ());
-    builder.CreateRet (ret);
-
     jit_operation::overload ol (fn, true, scalar, scalar, scalar);
     binary_ops[octave_value::op_ldiv].add_overload (ol);
     binary_ops[octave_value::op_el_ldiv].add_overload (ol);
   }
-  llvm::verifyFunction (*fn);
+
+  // In general, the result of scalar ^ scalar is a complex number. We might be
+  // able to improve on this if we keep track of the range of values varaibles
+  // can take on.
+  fn = create_function ("octave_jit_pow_scalar_scalar", complex_ret, scalar_t,
+                        scalar_t);
+  engine->addGlobalMapping (fn, reinterpret_cast<void *> (octave_jit_pow_scalar_scalar));
+  {
+    jit_operation::overload ol (wrap_complex (fn), false, complex, scalar,
+                                scalar);
+    binary_ops[octave_value::op_pow].add_overload (ol);
+    binary_ops[octave_value::op_el_pow].add_overload (ol);
+  }
 
   // now for binary complex operations
   add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd);
@@ -734,32 +780,29 @@
     llvm::Value *mlhs = llvm::UndefValue::get (vec4);
     llvm::Value *mrhs = mlhs;
 
-    llvm::Value *temp = builder.CreateExtractElement (lhs, zero);
+    llvm::Value *temp = complex_real (lhs);
     mlhs = builder.CreateInsertElement (mlhs, temp, zero);
     mlhs = builder.CreateInsertElement (mlhs, temp, two);
-    temp = builder.CreateExtractElement (lhs, one);
+    temp = complex_imag (lhs);
     mlhs = builder.CreateInsertElement (mlhs, temp, one);
     mlhs = builder.CreateInsertElement (mlhs, temp, three);
 
-    temp = builder.CreateExtractElement (rhs, zero);
+    temp = complex_real (rhs);
     mrhs = builder.CreateInsertElement (mrhs, temp, zero);
     mrhs = builder.CreateInsertElement (mrhs, temp, three);
-    temp = builder.CreateExtractElement (rhs, one);
+    temp = complex_imag (rhs);
     mrhs = builder.CreateInsertElement (mrhs, temp, one);
     mrhs = builder.CreateInsertElement (mrhs, temp, two);
 
     llvm::Value *mres = builder.CreateFMul (mlhs, mrhs);
-    llvm::Value *ret = llvm::UndefValue::get (complex_t);
     llvm::Value *tlhs = builder.CreateExtractElement (mres, zero);
     llvm::Value *trhs = builder.CreateExtractElement (mres, one);
-    temp = builder.CreateFSub (tlhs, trhs);
-    ret = builder.CreateInsertElement (ret, temp, zero);
+    llvm::Value *ret_real = builder.CreateFSub (tlhs, trhs);
 
     tlhs = builder.CreateExtractElement (mres, two);
     trhs = builder.CreateExtractElement (mres, three);
-    temp = builder.CreateFAdd (tlhs, trhs);
-    ret = builder.CreateInsertElement (ret, temp, one);
-    builder.CreateRet (ret);
+    llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs);
+    builder.CreateRet (complex_new (ret_real, ret_imag));
 
     jit_operation::overload ol (fn, false, complex, complex, complex);
     binary_ops[octave_value::op_mul].add_overload (ol);
@@ -767,42 +810,6 @@
   }
   llvm::verifyFunction (*fn);
 
-  fn = create_function ("octave_jit_*_scalar_complex", complex, scalar,
-                        complex);
-  llvm::Function *mul_scalar_complex = fn;
-  body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
-  {
-    llvm::Value *lhs = fn->arg_begin ();
-    llvm::Value *tlhs = llvm::UndefValue::get (complex_t);
-    tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (0));
-    tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (1));
-
-    llvm::Value *rhs = ++fn->arg_begin ();
-    builder.CreateRet (builder.CreateFMul (tlhs, rhs));
-
-    jit_operation::overload ol (fn, false, complex, scalar, complex);
-    binary_ops[octave_value::op_mul].add_overload (ol);
-    binary_ops[octave_value::op_el_mul].add_overload (ol);
-  }
-  llvm::verifyFunction (*fn);
-
-  fn = create_function ("octave_jit_*_complex_scalar", complex, complex,
-                        scalar);
-  body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
-  {
-    llvm::Value *ret = builder.CreateCall2 (mul_scalar_complex,
-                                            ++fn->arg_begin (),
-                                            fn->arg_begin ());
-    builder.CreateRet (ret);
-
-    jit_operation::overload ol (fn, false, complex, complex,  scalar);
-    binary_ops[octave_value::op_mul].add_overload (ol);
-    binary_ops[octave_value::op_el_mul].add_overload (ol);
-  }
-  llvm::verifyFunction (*fn);
-
   llvm::Function *complex_div = create_function ("octave_jit_complex_div",
                                                  complex_ret, complex_ret,
                                                  complex_ret);
@@ -815,18 +822,114 @@
     binary_ops[octave_value::op_ldiv].add_overload (ol);
   }
 
-  fn = create_function ("octave_jit_\\_complex_complex", complex, complex,
+  fn = mirror_binary (complex_div);
+  {
+    jit_operation::overload ol (fn, true, complex, complex, complex);
+    binary_ops[octave_value::op_ldiv].add_overload (ol);
+    binary_ops[octave_value::op_el_ldiv].add_overload (ol);
+  }
+
+  fn = create_function ("octave_jit_pow_complex_complex", complex_ret,
+                        complex_ret, complex_ret);
+  engine->addGlobalMapping (fn, reinterpret_cast<void *> (octave_jit_pow_complex_complex));
+  {
+    jit_operation::overload ol (wrap_complex (fn), false, complex, complex,
+                                complex);
+    binary_ops[octave_value::op_pow].add_overload (ol);
+    binary_ops[octave_value::op_el_pow].add_overload (ol);
+  }
+
+  fn = create_function ("octave_jit_*_scalar_complex", complex, scalar,
+                        complex);
+  llvm::Function *mul_scalar_complex = fn;
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *lhs = fn->arg_begin ();
+    llvm::Value *tlhs = complex_new (lhs, lhs);
+    llvm::Value *rhs = ++fn->arg_begin ();
+    builder.CreateRet (builder.CreateFMul (tlhs, rhs));
+
+    jit_operation::overload ol (fn, false, complex, scalar, complex);
+    binary_ops[octave_value::op_mul].add_overload (ol);
+    binary_ops[octave_value::op_el_mul].add_overload (ol);
+  }
+  llvm::verifyFunction (*fn);
+
+  fn = mirror_binary (mul_scalar_complex);
+  {
+    jit_operation::overload ol (fn, false, complex, complex,  scalar);
+    binary_ops[octave_value::op_mul].add_overload (ol);
+    binary_ops[octave_value::op_el_mul].add_overload (ol);
+  }
+
+  fn = create_function ("octave_jit_+_scalar_complex", complex, scalar,
                         complex);
   body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   {
-    builder.CreateRet (builder.CreateCall2 (complex_div, ++fn->arg_begin (),
-                                            fn->arg_begin ()));
-    jit_operation::overload ol (fn, true, complex, complex, complex);
-    binary_ops[octave_value::op_ldiv].add_overload (ol);
-    binary_ops[octave_value::op_el_ldiv].add_overload (ol);
+    llvm::Value *lhs = fn->arg_begin ();
+    llvm::Value *rhs = ++fn->arg_begin ();
+    llvm::Value *real = builder.CreateFAdd (lhs, complex_real (rhs));
+    builder.CreateRet (complex_real (rhs, real));
+    llvm::verifyFunction (*fn);
+
+    binary_ops[octave_value::op_add].add_overload (fn, false, complex, scalar,
+                                                   complex);
+    fn = mirror_binary (fn);
+    binary_ops[octave_value::op_add].add_overload (fn, false, complex, complex,
+                                                   scalar);
+  }
+
+  fn = create_function ("octave_jit_-_complex_scalar", complex, complex,
+                        scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *lhs = fn->arg_begin ();
+    llvm::Value *rhs = ++fn->arg_begin ();
+    llvm::Value *real = builder.CreateFSub (complex_real (lhs), rhs);
+    builder.CreateRet (complex_real (lhs, real));
+    llvm::verifyFunction (*fn);
+
+    binary_ops[octave_value::op_sub].add_overload (fn, false, complex, complex,
+                                                   scalar);
   }
-  llvm::verifyFunction (*fn);
+
+  fn = create_function ("octave_jit_-_scalar_complex", complex, scalar,
+                        complex);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *lhs = fn->arg_begin ();
+    llvm::Value *rhs = ++fn->arg_begin ();
+    llvm::Value *real = builder.CreateFSub (lhs, complex_real (rhs));
+    builder.CreateRet (complex_real (rhs, real));
+    llvm::verifyFunction (*fn);
+
+    binary_ops[octave_value::op_sub].add_overload (fn, false, complex, scalar,
+                                                   complex);
+  }
+
+  fn = create_function ("octave_jit_pow_scalar_complex", complex_ret,
+                        scalar_t, complex_ret);
+  engine->addGlobalMapping (fn, reinterpret_cast<void *> (octave_jit_pow_scalar_complex));
+  {
+    jit_operation::overload ol (wrap_complex (fn), false, complex, scalar,
+                                complex);
+    binary_ops[octave_value::op_pow].add_overload (ol);
+    binary_ops[octave_value::op_el_pow].add_overload (ol);
+  }
+
+  fn = create_function ("octave_jit_pow_complex_scalar", complex_ret,
+                        complex_ret, scalar_t);
+  engine->addGlobalMapping (fn, reinterpret_cast<void *> (octave_jit_pow_complex_complex));
+  {
+    jit_operation::overload ol (wrap_complex (fn), false, complex, complex,
+                                scalar);
+    binary_ops[octave_value::op_pow].add_overload (ol);
+    binary_ops[octave_value::op_el_pow].add_overload (ol);
+  }
 
   // now for binary index operators
   add_binary_op (index, octave_value::op_add, llvm::Instruction::Add);
@@ -1177,6 +1280,27 @@
   casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex,
                                            any);
 
+  // cast complex <- scalar
+  fn = create_function ("octave_jit_cast_complex_scalar", complex, scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0);
+    builder.CreateRet (complex_new (fn->arg_begin (), zero));
+    llvm::verifyFunction (*fn);
+  }
+  casts[complex->type_id ()].add_overload (fn, false, complex, scalar);
+
+  // cast scalar <- complex
+  fn = create_function ("octave_jit_cast_scalar_complex", scalar, complex);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    builder.CreateRet (complex_real (fn->arg_begin ()));
+    llvm::verifyFunction (*fn);
+  }
+  casts[scalar->type_id ()].add_overload (fn, false, scalar, complex);
+
   // cast any <- any
   fn = create_identity (any);
   casts[any->type_id ()].add_overload (fn, false, any, any);
@@ -1420,6 +1544,27 @@
 }
 
 llvm::Function *
+jit_typeinfo::mirror_binary (llvm::Function *fn)
+{
+  llvm::FunctionType *fn_type = fn->getFunctionType ();
+  llvm::Function *ret = create_function (fn->getName () + "_reverse",
+                                         fn_type->getReturnType (),
+                                         fn_type->getParamType (1),
+                                         fn_type->getParamType (0));
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", ret);
+  builder.SetInsertPoint (body);
+  llvm::Value *result = builder.CreateCall2 (fn, ++ret->arg_begin (),
+                                             ret->arg_begin ());
+  if (ret->getReturnType () == builder.getVoidTy ())
+    builder.CreateRetVoid ();
+  else
+    builder.CreateRet (result);
+
+  llvm::verifyFunction (*ret);
+  return ret;
+}
+
+llvm::Function *
 jit_typeinfo::wrap_complex (llvm::Function *wrap)
 {
   llvm::SmallVector<llvm::Type *, 5> new_args;
@@ -1490,6 +1635,38 @@
   return builder.CreateInsertElement (ret, imag, builder.getInt32 (1));
 }
 
+llvm::Value *
+jit_typeinfo::complex_real (llvm::Value *cx)
+{
+  return builder.CreateExtractElement (cx, builder.getInt32 (0));
+}
+
+llvm::Value *
+jit_typeinfo::complex_real (llvm::Value *cx, llvm::Value *real)
+{
+  return builder.CreateInsertElement (cx, real, builder.getInt32 (0));
+}
+
+llvm::Value *
+jit_typeinfo::complex_imag (llvm::Value *cx)
+{
+  return builder.CreateExtractElement (cx, builder.getInt32 (1));
+}
+
+llvm::Value *
+jit_typeinfo::complex_imag (llvm::Value *cx, llvm::Value *imag)
+{
+  return builder.CreateInsertElement (cx, imag, builder.getInt32 (1));
+}
+
+llvm::Value *
+jit_typeinfo::complex_new (llvm::Value *real, llvm::Value *imag)
+{
+  llvm::Value *ret = llvm::UndefValue::get (complex->to_llvm ());
+  ret = complex_real (ret, real);
+  return complex_imag (ret, imag);
+}
+
 jit_type *
 jit_typeinfo::do_type_of (const octave_value &ov) const
 {
--- a/src/pt-jit.h	Tue Jul 10 21:25:51 2012 -0500
+++ b/src/pt-jit.h	Wed Jul 11 14:42:56 2012 -0500
@@ -668,12 +668,24 @@
 
   octave_builtin *find_builtin (const std::string& name);
 
+  llvm::Function *mirror_binary (llvm::Function *fn);
+
   llvm::Function *wrap_complex (llvm::Function *wrap);
 
   llvm::Value *pack_complex (llvm::Value *cplx);
 
   llvm::Value *unpack_complex (llvm::Value *result);
 
+  llvm::Value *complex_real (llvm::Value *cx);
+
+  llvm::Value *complex_real (llvm::Value *cx, llvm::Value *real);
+
+  llvm::Value *complex_imag (llvm::Value *cx);
+
+  llvm::Value *complex_imag (llvm::Value *cx, llvm::Value *imag);
+
+  llvm::Value *complex_new (llvm::Value *real, llvm::Value *imag);
+
   static jit_typeinfo *instance;
 
   llvm::Module *module;