# HG changeset patch # User Max Brister # Date 1347499131 21600 # Node ID 8355fddce8152e9afe4abf0c07a74a06da7ade52 # Parent 715220d2b511d0a21a1fcc8f286cf70acc3bf543 Use sret and do not use save/restore stack (bug #37308) * jit-typeinfo.cc (octave_jit_grab_matrix, octave_jit_cast_matrix_any, octave_jit_paren_subsasgn_impl, octave_jit_paren_scalar_subsasgn, octave_jit_paren_subsasgn_matrix_range): Return matrix directly. (octave_jit_cast_range_any): Return range directly. (jit_function::jit_function): Maybe mark llvm function return as sret. (jit_function::call): Maybe mark llvm call sret and place allocas at function entry. (jit_function::do_return): Handle new parameter, verify. (jit_typeinfo::jit_typeinfo): Match C++ std::complex type better, pass jit_convention::external explicitly, and disable right complex division. (jit_typeinfo::create_identity): Improve name. (jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): Handle changed complex format. * jit-typeinfo.h (jit_array::jit_array): New overload. (jit_type::mark_sret, jit_type::mark_pointer_arg): Remove default convention. (jit_function::do_return): Add verify parameter. * pt-jit.cc (jit_convert_llvm::convert_function): Store the jit_function. (jit_convert_llvm::visit): Call do_return if converting a function. * pt-jit.h (jit_convert_llvm::creating): New member variable. 
diff -r 715220d2b511 -r 8355fddce815 libinterp/interp-core/jit-typeinfo.cc --- a/libinterp/interp-core/jit-typeinfo.cc Wed Sep 12 20:06:05 2012 -0700 +++ b/libinterp/interp-core/jit-typeinfo.cc Wed Sep 12 19:18:51 2012 -0600 @@ -113,10 +113,10 @@ return obv; } -extern "C" void -octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m) +extern "C" jit_matrix +octave_jit_grab_matrix (jit_matrix *m) { - *result = *m->array; + return *m->array; } extern "C" octave_base_value * @@ -130,12 +130,12 @@ return rep; } -extern "C" void -octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) +extern "C" jit_matrix +octave_jit_cast_matrix_any (octave_base_value *obv) { NDArray m = obv->array_value (); - *ret = m; obv->release (); + return m; } extern "C" octave_base_value * @@ -148,13 +148,13 @@ return rep; } -extern "C" void -octave_jit_cast_range_any (jit_range *ret, octave_base_value *obv) +extern "C" jit_range +octave_jit_cast_range_any (octave_base_value *obv) { jit_range r (obv->range_value ()); - *ret = r; obv->release (); + return r; } extern "C" double @@ -228,9 +228,9 @@ } } -extern "C" void -octave_jit_paren_subsasgn_impl (jit_matrix *ret, jit_matrix *mat, - octave_idx_type index, double value) +extern "C" jit_matrix +octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, + double value) { NDArray *array = mat->array; if (array->nelem () < index) @@ -240,7 +240,7 @@ data[index - 1] = value; mat->update (); - *ret = *mat; + return *mat; } static void @@ -272,12 +272,12 @@ } } -extern "C" void -octave_jit_paren_scalar_subsasgn (jit_matrix *ret, jit_matrix *mat, - double *indices, octave_idx_type idx_count, - double value) +extern "C" jit_matrix +octave_jit_paren_scalar_subsasgn (jit_matrix *mat, double *indices, + octave_idx_type idx_count, double value) { // FIXME: Replace this with a more optimal version + jit_matrix ret; try { Array idx; @@ -286,17 +286,19 @@ Matrix temp (1, 1); temp.xelem(0) = value; mat->array->assign (idx, temp); - 
ret->update (mat->array); + ret.update (mat->array); } catch (const octave_execution_exception&) { gripe_library_execution_error (); } + + return ret; } -extern "C" void -octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, - jit_range *index, double value) +extern "C" jit_matrix +octave_jit_paren_subsasgn_matrix_range (jit_matrix *mat, jit_range *index, + double value) { NDArray *array = mat->array; bool done = false; @@ -340,7 +342,9 @@ array->assign (idx, avalue); } - result->update (array); + jit_matrix ret; + ret.update (array); + return ret; } extern "C" double @@ -562,6 +566,10 @@ llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, aname, module); + + if (sret ()) + llvm_function->addAttribute (1, llvm::Attribute::StructRet); + if (call_conv == jit_convention::internal) llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); } @@ -620,12 +628,18 @@ llvm::SmallVector llvm_args; llvm_args.reserve (in_args.size () + sret ()); - llvm::Value *sret_mem = 0; - llvm::Value *saved_stack = 0; + llvm::BasicBlock *insert_block = builder.GetInsertBlock (); + llvm::Function *parent = insert_block->getParent (); + assert (parent); + + // we insert allocas inside the prelude block to prevent stack overflows + llvm::BasicBlock& prelude = parent->getEntryBlock (); + llvm::IRBuilder<> pre_builder (&prelude, prelude.begin ()); + + llvm::AllocaInst *sret_mem = 0; if (sret ()) { - saved_stack = builder.CreateCall (stacksave); - sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv)); + sret_mem = pre_builder.CreateAlloca (mresult->packed_type (call_conv)); llvm_args.push_back (sret_mem); } @@ -638,19 +652,23 @@ if (args[i]->pointer_arg (call_conv)) { - if (! 
saved_stack) - saved_stack = builder.CreateCall (stacksave); - - arg = builder.CreateAlloca (args[i]->to_llvm ()); - builder.CreateStore (in_args[i], arg); + llvm::Type *ty = args[i]->packed_type (call_conv); + llvm::Value *alloca = pre_builder.CreateAlloca (ty); + builder.CreateStore (arg, alloca); + arg = alloca; } llvm_args.push_back (arg); } - llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args); - if (sret_mem) - ret = builder.CreateLoad (sret_mem); + llvm::CallInst *callinst = builder.CreateCall (llvm_function, llvm_args); + llvm::Value *ret = callinst; + + if (sret ()) + { + callinst->addAttribute (1, llvm::Attribute::StructRet); + ret = builder.CreateLoad (sret_mem); + } if (mresult) { @@ -659,14 +677,6 @@ ret = unpack (builder, ret); } - if (saved_stack) - { - llvm::Function *stackrestore - = llvm::Intrinsic::getDeclaration (module, - llvm::Intrinsic::stackrestore); - builder.CreateCall (stackrestore, saved_stack); - } - return ret; } @@ -691,7 +701,8 @@ } void -jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval) +jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval, + bool verify) { assert (! rval == ! 
mresult); @@ -702,14 +713,18 @@ rval = convert (builder, rval); if (sret ()) - builder.CreateStore (rval, llvm_function->arg_begin ()); + { + builder.CreateStore (rval, llvm_function->arg_begin ()); + builder.CreateRetVoid (); + } else builder.CreateRet (rval); } else builder.CreateRetVoid (); - llvm::verifyFunction (*llvm_function); + if (verify) + llvm::verifyFunction (*llvm_function); } void @@ -1032,9 +1047,14 @@ // complex_ret is what is passed to C functions in order to get calling // convention right + llvm::Type *cmplx_inner_cont[] = {scalar_t, scalar_t}; + llvm::StructType *cmplx_inner = llvm::StructType::create (cmplx_inner_cont); + complex_ret = llvm::StructType::create (context, "complex_ret"); - llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; - complex_ret->setBody (complex_ret_contents); + { + llvm::Type *contents[] = {cmplx_inner}; + complex_ret->setBody (contents); + } // create types any = new_type ("any", 0, any_t); @@ -1059,18 +1079,18 @@ // specify calling conventions // FIXME: We should detect architecture and do something sane based on that // here we assume x86 or x86_64 - matrix->mark_sret (); - matrix->mark_pointer_arg (); + matrix->mark_sret (jit_convention::external); + matrix->mark_pointer_arg (jit_convention::external); - range->mark_sret (); - range->mark_pointer_arg (); + range->mark_sret (jit_convention::external); + range->mark_pointer_arg (jit_convention::external); complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); complex->set_packed_type (jit_convention::external, complex_ret); if (sizeof (void *) == 4) - complex->mark_sret (); + complex->mark_sret (jit_convention::external); paren_subsref_fn.initialize (module, engine); paren_subsasgn_fn.initialize (module, engine); @@ -1333,9 +1353,9 @@ binary_ops[octave_value::op_div].add_overload (fn); binary_ops[octave_value::op_ldiv].add_overload (fn); - fn = mirror_binary 
(complex_div); - binary_ops[octave_value::op_ldiv].add_overload (fn); - binary_ops[octave_value::op_el_ldiv].add_overload (fn); + // fn = mirror_binary (complex_div); + // binary_ops[octave_value::op_ldiv].add_overload (fn); + // binary_ops[octave_value::op_el_ldiv].add_overload (fn); fn = create_function (jit_convention::external, "octave_jit_pow_complex_complex", complex, complex, @@ -1990,8 +2010,11 @@ if (! identities[id].valid ()) { - jit_function fn = create_function (jit_convention::internal, "id", type, - type); + std::stringstream name; + name << "id_" << type->name (); + jit_function fn = create_function (jit_convention::internal, name.str (), + type, type); + llvm::BasicBlock *body = fn.new_block (); builder.SetInsertPoint (body); fn.do_return (builder, fn.argument (builder, 0)); @@ -2141,17 +2164,24 @@ llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0)); llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1)); llvm::Value *ret = llvm::UndefValue::get (complex_ret); - ret = bld.CreateInsertValue (ret, real, 0); - return bld.CreateInsertValue (ret, imag, 1); + + unsigned int re_idx[] = {0, 0}; + unsigned int im_idx[] = {0, 1}; + ret = bld.CreateInsertValue (ret, real, re_idx); + return bld.CreateInsertValue (ret, imag, im_idx); } llvm::Value * jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result) { + unsigned int re_idx[] = {0, 0}; + unsigned int im_idx[] = {0, 1}; + llvm::Type *complex_t = get_complex ()->to_llvm (); - llvm::Value *real = bld.CreateExtractValue (result, 0); - llvm::Value *imag = bld.CreateExtractValue (result, 1); + llvm::Value *real = bld.CreateExtractValue (result, re_idx); + llvm::Value *imag = bld.CreateExtractValue (result, im_idx); llvm::Value *ret = llvm::UndefValue::get (complex_t); + ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0)); return bld.CreateInsertElement (ret, imag, bld.getInt32 (1)); } diff -r 715220d2b511 -r 8355fddce815 
libinterp/interp-core/jit-typeinfo.h --- a/libinterp/interp-core/jit-typeinfo.h Wed Sep 12 20:06:05 2012 -0700 +++ b/libinterp/interp-core/jit-typeinfo.h Wed Sep 12 19:18:51 2012 -0600 @@ -66,6 +66,8 @@ struct jit_array { + jit_array () : array (0) {} + jit_array (T& from) : array (new T (from)) { update (); @@ -161,7 +163,7 @@ // retval. (on the stack) bool sret (jit_convention::type cc) const { return msret[cc]; } - void mark_sret (jit_convention::type cc = jit_convention::external) + void mark_sret (jit_convention::type cc) { msret[cc] = true; } // A function like: void foo (mytype arg0) @@ -169,7 +171,7 @@ // Basically just pass by reference. bool pointer_arg (jit_convention::type cc) const { return mpointer_arg[cc]; } - void mark_pointer_arg (jit_convention::type cc = jit_convention::external) + void mark_pointer_arg (jit_convention::type cc) { mpointer_arg[cc] = true; } // Convert into an equivalent form before calling. For example, complex is @@ -278,7 +280,8 @@ llvm::Value *argument (llvm::IRBuilderD& builder, size_t idx) const; - void do_return (llvm::IRBuilderD& builder, llvm::Value *rval = 0); + void do_return (llvm::IRBuilderD& builder, llvm::Value *rval = 0, + bool verify = true); llvm::Function *to_llvm (void) const { return llvm_function; } diff -r 715220d2b511 -r 8355fddce815 libinterp/interp-core/pt-jit.cc --- a/libinterp/interp-core/pt-jit.cc Wed Sep 12 20:06:05 2012 -0700 +++ b/libinterp/interp-core/pt-jit.cc Wed Sep 12 19:18:51 2012 -0600 @@ -1075,8 +1075,8 @@ jit_return *ret = dynamic_cast (final_block->back ()); assert (ret); - jit_function creating = jit_function (module, jit_convention::internal, - "foobar", ret->result_type (), args); + creating = jit_function (module, jit_convention::internal, + "foobar", ret->result_type (), args); function = creating.to_llvm (); try @@ -1280,10 +1280,16 @@ jit_convert_llvm::visit (jit_return& ret) { jit_value *res = ret.result (); - if (res) - builder.CreateRet (res->to_llvm ()); + + if 
(converting_function) + creating.do_return (builder, res->to_llvm (), false); else - builder.CreateRetVoid (); + { + if (res) + builder.CreateRet (res->to_llvm ()); + else + builder.CreateRetVoid (); + } } void diff -r 715220d2b511 -r 8355fddce815 libinterp/interp-core/pt-jit.h --- a/libinterp/interp-core/pt-jit.h Wed Sep 12 20:06:05 2012 -0700 +++ b/libinterp/interp-core/pt-jit.h Wed Sep 12 19:18:51 2012 -0600 @@ -276,6 +276,9 @@ bool converting_function; + // only used if we are converting a function + jit_function creating; + llvm::Function *function; llvm::BasicBlock *prelude;