comparison libinterp/interp-core/jit-typeinfo.cc @ 15370:8355fddce815

Use sret and do not use save/restore stack (bug #37308) * jit-typeinfo.cc (octave_jit_grab_matrix, octave_jit_cast_matrix_any, octave_jit_paren_subsasgn_impl, octave_jit_paren_scalar_subsasgn, octave_jit_paren_subsasgn_matrix_range): Return matrix directly. (octave_jit_cast_range_any): Return range directly. (jit_function::jit_function): Maybe mark llvm function return as sret. (jit_function::call): Maybe mark llvm call sret and place allocas at function entry. (jit_function::do_return): Handle new parameter, verify. (jit_typeinfo::jit_typeinfo): Match C++ std::complex type better, pass jit_convention::external explicitly, and disable right complex division. (jit_typeinfo::create_identity): Improve name. (jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): Handle changed complex format. * jit-typeinfo.h (jit_array::jit_array): New overload. (jit_type::mark_sret, jit_type::mark_pointer_arg): Remove default convention. (jit_function::do_return): Add verify parameter. * pt-jit.cc (jit_convert_llvm::convert_function): Store the jit_function. (jit_convert::visit): Call do_return if converting a function. * pt-jit.h (jit_convert_llvm::creating): New member variable.
author Max Brister <max@2bass.com>
date Wed, 12 Sep 2012 19:18:51 -0600
parents 3f43e9d6d86e
children 8ccb187b24e9
comparison
equal deleted inserted replaced
15369:715220d2b511 15370:8355fddce815
111 { 111 {
112 obv->grab (); 112 obv->grab ();
113 return obv; 113 return obv;
114 } 114 }
115 115
116 extern "C" void 116 extern "C" jit_matrix
117 octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m) 117 octave_jit_grab_matrix (jit_matrix *m)
118 { 118 {
119 *result = *m->array; 119 return *m->array;
120 } 120 }
121 121
122 extern "C" octave_base_value * 122 extern "C" octave_base_value *
123 octave_jit_cast_any_matrix (jit_matrix *m) 123 octave_jit_cast_any_matrix (jit_matrix *m)
124 { 124 {
128 delete m->array; 128 delete m->array;
129 129
130 return rep; 130 return rep;
131 } 131 }
132 132
133 extern "C" void 133 extern "C" jit_matrix
134 octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) 134 octave_jit_cast_matrix_any (octave_base_value *obv)
135 { 135 {
136 NDArray m = obv->array_value (); 136 NDArray m = obv->array_value ();
137 *ret = m;
138 obv->release (); 137 obv->release ();
138 return m;
139 } 139 }
140 140
141 extern "C" octave_base_value * 141 extern "C" octave_base_value *
142 octave_jit_cast_any_range (jit_range *rng) 142 octave_jit_cast_any_range (jit_range *rng)
143 { 143 {
146 octave_base_value *rep = ret.internal_rep (); 146 octave_base_value *rep = ret.internal_rep ();
147 rep->grab (); 147 rep->grab ();
148 148
149 return rep; 149 return rep;
150 } 150 }
151 extern "C" void 151 extern "C" jit_range
152 octave_jit_cast_range_any (jit_range *ret, octave_base_value *obv) 152 octave_jit_cast_range_any (octave_base_value *obv)
153 { 153 {
154 154
155 jit_range r (obv->range_value ()); 155 jit_range r (obv->range_value ());
156 *ret = r;
157 obv->release (); 156 obv->release ();
157 return r;
158 } 158 }
159 159
160 extern "C" double 160 extern "C" double
161 octave_jit_cast_scalar_any (octave_base_value *obv) 161 octave_jit_cast_scalar_any (octave_base_value *obv)
162 { 162 {
226 { 226 {
227 gripe_library_execution_error (); 227 gripe_library_execution_error ();
228 } 228 }
229 } 229 }
230 230
231 extern "C" void 231 extern "C" jit_matrix
232 octave_jit_paren_subsasgn_impl (jit_matrix *ret, jit_matrix *mat, 232 octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index,
233 octave_idx_type index, double value) 233 double value)
234 { 234 {
235 NDArray *array = mat->array; 235 NDArray *array = mat->array;
236 if (array->nelem () < index) 236 if (array->nelem () < index)
237 array->resize1 (index); 237 array->resize1 (index);
238 238
239 double *data = array->fortran_vec (); 239 double *data = array->fortran_vec ();
240 data[index - 1] = value; 240 data[index - 1] = value;
241 241
242 mat->update (); 242 mat->update ();
243 *ret = *mat; 243 return *mat;
244 } 244 }
245 245
246 static void 246 static void
247 make_indices (double *indices, octave_idx_type idx_count, 247 make_indices (double *indices, octave_idx_type idx_count,
248 Array<idx_vector>& result) 248 Array<idx_vector>& result)
270 gripe_library_execution_error (); 270 gripe_library_execution_error ();
271 return 0; 271 return 0;
272 } 272 }
273 } 273 }
274 274
275 extern "C" void 275 extern "C" jit_matrix
276 octave_jit_paren_scalar_subsasgn (jit_matrix *ret, jit_matrix *mat, 276 octave_jit_paren_scalar_subsasgn (jit_matrix *mat, double *indices,
277 double *indices, octave_idx_type idx_count, 277 octave_idx_type idx_count, double value)
278 double value)
279 { 278 {
280 // FIXME: Replace this with a more optimal version 279 // FIXME: Replace this with a more optimal version
280 jit_matrix ret;
281 try 281 try
282 { 282 {
283 Array<idx_vector> idx; 283 Array<idx_vector> idx;
284 make_indices (indices, idx_count, idx); 284 make_indices (indices, idx_count, idx);
285 285
286 Matrix temp (1, 1); 286 Matrix temp (1, 1);
287 temp.xelem(0) = value; 287 temp.xelem(0) = value;
288 mat->array->assign (idx, temp); 288 mat->array->assign (idx, temp);
289 ret->update (mat->array); 289 ret.update (mat->array);
290 } 290 }
291 catch (const octave_execution_exception&) 291 catch (const octave_execution_exception&)
292 { 292 {
293 gripe_library_execution_error (); 293 gripe_library_execution_error ();
294 } 294 }
295 } 295
296 296 return ret;
297 extern "C" void 297 }
298 octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, 298
299 jit_range *index, double value) 299 extern "C" jit_matrix
300 octave_jit_paren_subsasgn_matrix_range (jit_matrix *mat, jit_range *index,
301 double value)
300 { 302 {
301 NDArray *array = mat->array; 303 NDArray *array = mat->array;
302 bool done = false; 304 bool done = false;
303 305
304 // optimize for the simple case (no resizing and no errors) 306 // optimize for the simple case (no resizing and no errors)
338 NDArray avalue (dim_vector (1, 1)); 340 NDArray avalue (dim_vector (1, 1));
339 avalue.xelem (0) = value; 341 avalue.xelem (0) = value;
340 array->assign (idx, avalue); 342 array->assign (idx, avalue);
341 } 343 }
342 344
343 result->update (array); 345 jit_matrix ret;
346 ret.update (array);
347 return ret;
344 } 348 }
345 349
346 extern "C" double 350 extern "C" double
347 octave_jit_end_matrix (jit_matrix *mat, octave_idx_type idx, 351 octave_jit_end_matrix (jit_matrix *mat, octave_idx_type idx,
348 octave_idx_type count) 352 octave_idx_type count)
560 // we mark all functions as external linkage because this prevents llvm 564 // we mark all functions as external linkage because this prevents llvm
561 // from getting rid of always inline functions 565 // from getting rid of always inline functions
562 llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); 566 llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false);
563 llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, 567 llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
564 aname, module); 568 aname, module);
569
570 if (sret ())
571 llvm_function->addAttribute (1, llvm::Attribute::StructRet);
572
565 if (call_conv == jit_convention::internal) 573 if (call_conv == jit_convention::internal)
566 llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); 574 llvm_function->addFnAttr (llvm::Attribute::AlwaysInline);
567 } 575 }
568 576
569 jit_function::jit_function (const jit_function& fn, jit_type *aresult, 577 jit_function::jit_function (const jit_function& fn, jit_type *aresult,
618 llvm::Function *stacksave 626 llvm::Function *stacksave
619 = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); 627 = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave);
620 llvm::SmallVector<llvm::Value *, 10> llvm_args; 628 llvm::SmallVector<llvm::Value *, 10> llvm_args;
621 llvm_args.reserve (in_args.size () + sret ()); 629 llvm_args.reserve (in_args.size () + sret ());
622 630
623 llvm::Value *sret_mem = 0; 631 llvm::BasicBlock *insert_block = builder.GetInsertBlock ();
624 llvm::Value *saved_stack = 0; 632 llvm::Function *parent = insert_block->getParent ();
633 assert (parent);
634
635 // we insert allocas inside the prelude block to prevent stack overflows
636 llvm::BasicBlock& prelude = parent->getEntryBlock ();
637 llvm::IRBuilder<> pre_builder (&prelude, prelude.begin ());
638
639 llvm::AllocaInst *sret_mem = 0;
625 if (sret ()) 640 if (sret ())
626 { 641 {
627 saved_stack = builder.CreateCall (stacksave); 642 sret_mem = pre_builder.CreateAlloca (mresult->packed_type (call_conv));
628 sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv));
629 llvm_args.push_back (sret_mem); 643 llvm_args.push_back (sret_mem);
630 } 644 }
631 645
632 for (size_t i = 0; i < in_args.size (); ++i) 646 for (size_t i = 0; i < in_args.size (); ++i)
633 { 647 {
636 if (convert) 650 if (convert)
637 arg = convert (builder, arg); 651 arg = convert (builder, arg);
638 652
639 if (args[i]->pointer_arg (call_conv)) 653 if (args[i]->pointer_arg (call_conv))
640 { 654 {
641 if (! saved_stack) 655 llvm::Type *ty = args[i]->packed_type (call_conv);
642 saved_stack = builder.CreateCall (stacksave); 656 llvm::Value *alloca = pre_builder.CreateAlloca (ty);
643 657 builder.CreateStore (arg, alloca);
644 arg = builder.CreateAlloca (args[i]->to_llvm ()); 658 arg = alloca;
645 builder.CreateStore (in_args[i], arg);
646 } 659 }
647 660
648 llvm_args.push_back (arg); 661 llvm_args.push_back (arg);
649 } 662 }
650 663
651 llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args); 664 llvm::CallInst *callinst = builder.CreateCall (llvm_function, llvm_args);
652 if (sret_mem) 665 llvm::Value *ret = callinst;
653 ret = builder.CreateLoad (sret_mem); 666
667 if (sret ())
668 {
669 callinst->addAttribute (1, llvm::Attribute::StructRet);
670 ret = builder.CreateLoad (sret_mem);
671 }
654 672
655 if (mresult) 673 if (mresult)
656 { 674 {
657 jit_type::convert_fn unpack = mresult->unpack (call_conv); 675 jit_type::convert_fn unpack = mresult->unpack (call_conv);
658 if (unpack) 676 if (unpack)
659 ret = unpack (builder, ret); 677 ret = unpack (builder, ret);
660 }
661
662 if (saved_stack)
663 {
664 llvm::Function *stackrestore
665 = llvm::Intrinsic::getDeclaration (module,
666 llvm::Intrinsic::stackrestore);
667 builder.CreateCall (stackrestore, saved_stack);
668 } 678 }
669 679
670 return ret; 680 return ret;
671 } 681 }
672 682
689 699
690 return iter; 700 return iter;
691 } 701 }
692 702
693 void 703 void
694 jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval) 704 jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval,
705 bool verify)
695 { 706 {
696 assert (! rval == ! mresult); 707 assert (! rval == ! mresult);
697 708
698 if (rval) 709 if (rval)
699 { 710 {
700 jit_type::convert_fn convert = mresult->pack (call_conv); 711 jit_type::convert_fn convert = mresult->pack (call_conv);
701 if (convert) 712 if (convert)
702 rval = convert (builder, rval); 713 rval = convert (builder, rval);
703 714
704 if (sret ()) 715 if (sret ())
705 builder.CreateStore (rval, llvm_function->arg_begin ()); 716 {
717 builder.CreateStore (rval, llvm_function->arg_begin ());
718 builder.CreateRetVoid ();
719 }
706 else 720 else
707 builder.CreateRet (rval); 721 builder.CreateRet (rval);
708 } 722 }
709 else 723 else
710 builder.CreateRetVoid (); 724 builder.CreateRetVoid ();
711 725
712 llvm::verifyFunction (*llvm_function); 726 if (verify)
727 llvm::verifyFunction (*llvm_function);
713 } 728 }
714 729
715 void 730 void
716 jit_function::do_add_mapping (llvm::ExecutionEngine *engine, void *fn) 731 jit_function::do_add_mapping (llvm::ExecutionEngine *engine, void *fn)
717 { 732 {
1030 1045
1031 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); 1046 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2);
1032 1047
1033 // complex_ret is what is passed to C functions in order to get calling 1048 // complex_ret is what is passed to C functions in order to get calling
1034 // convention right 1049 // convention right
1050 llvm::Type *cmplx_inner_cont[] = {scalar_t, scalar_t};
1051 llvm::StructType *cmplx_inner = llvm::StructType::create (cmplx_inner_cont);
1052
1035 complex_ret = llvm::StructType::create (context, "complex_ret"); 1053 complex_ret = llvm::StructType::create (context, "complex_ret");
1036 llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; 1054 {
1037 complex_ret->setBody (complex_ret_contents); 1055 llvm::Type *contents[] = {cmplx_inner};
1056 complex_ret->setBody (contents);
1057 }
1038 1058
1039 // create types 1059 // create types
1040 any = new_type ("any", 0, any_t); 1060 any = new_type ("any", 0, any_t);
1041 matrix = new_type ("matrix", any, matrix_t); 1061 matrix = new_type ("matrix", any, matrix_t);
1042 complex = new_type ("complex", any, complex_t); 1062 complex = new_type ("complex", any, complex_t);
1057 identities.resize (next_id + 1); 1077 identities.resize (next_id + 1);
1058 1078
1059 // specify calling conventions 1079 // specify calling conventions
1060 // FIXME: We should detect architecture and do something sane based on that 1080 // FIXME: We should detect architecture and do something sane based on that
1061 // here we assume x86 or x86_64 1081 // here we assume x86 or x86_64
1062 matrix->mark_sret (); 1082 matrix->mark_sret (jit_convention::external);
1063 matrix->mark_pointer_arg (); 1083 matrix->mark_pointer_arg (jit_convention::external);
1064 1084
1065 range->mark_sret (); 1085 range->mark_sret (jit_convention::external);
1066 range->mark_pointer_arg (); 1086 range->mark_pointer_arg (jit_convention::external);
1067 1087
1068 complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); 1088 complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex);
1069 complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); 1089 complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex);
1070 complex->set_packed_type (jit_convention::external, complex_ret); 1090 complex->set_packed_type (jit_convention::external, complex_ret);
1071 1091
1072 if (sizeof (void *) == 4) 1092 if (sizeof (void *) == 4)
1073 complex->mark_sret (); 1093 complex->mark_sret (jit_convention::external);
1074 1094
1075 paren_subsref_fn.initialize (module, engine); 1095 paren_subsref_fn.initialize (module, engine);
1076 paren_subsasgn_fn.initialize (module, engine); 1096 paren_subsasgn_fn.initialize (module, engine);
1077 1097
1078 // bind global variables 1098 // bind global variables
1331 complex_div.add_mapping (engine, &octave_jit_complex_div); 1351 complex_div.add_mapping (engine, &octave_jit_complex_div);
1332 complex_div.mark_can_error (); 1352 complex_div.mark_can_error ();
1333 binary_ops[octave_value::op_div].add_overload (fn); 1353 binary_ops[octave_value::op_div].add_overload (fn);
1334 binary_ops[octave_value::op_ldiv].add_overload (fn); 1354 binary_ops[octave_value::op_ldiv].add_overload (fn);
1335 1355
1336 fn = mirror_binary (complex_div); 1356 // fn = mirror_binary (complex_div);
1337 binary_ops[octave_value::op_ldiv].add_overload (fn); 1357 // binary_ops[octave_value::op_ldiv].add_overload (fn);
1338 binary_ops[octave_value::op_el_ldiv].add_overload (fn); 1358 // binary_ops[octave_value::op_el_ldiv].add_overload (fn);
1339 1359
1340 fn = create_function (jit_convention::external, 1360 fn = create_function (jit_convention::external,
1341 "octave_jit_pow_complex_complex", complex, complex, 1361 "octave_jit_pow_complex_complex", complex, complex,
1342 complex); 1362 complex);
1343 fn.add_mapping (engine, &octave_jit_pow_complex_complex); 1363 fn.add_mapping (engine, &octave_jit_pow_complex_complex);
1988 if (id >= identities.size ()) 2008 if (id >= identities.size ())
1989 identities.resize (id + 1); 2009 identities.resize (id + 1);
1990 2010
1991 if (! identities[id].valid ()) 2011 if (! identities[id].valid ())
1992 { 2012 {
1993 jit_function fn = create_function (jit_convention::internal, "id", type, 2013 std::stringstream name;
1994 type); 2014 name << "id_" << type->name ();
2015 jit_function fn = create_function (jit_convention::internal, name.str (),
2016 type, type);
2017
1995 llvm::BasicBlock *body = fn.new_block (); 2018 llvm::BasicBlock *body = fn.new_block ();
1996 builder.SetInsertPoint (body); 2019 builder.SetInsertPoint (body);
1997 fn.do_return (builder, fn.argument (builder, 0)); 2020 fn.do_return (builder, fn.argument (builder, 0));
1998 return identities[id] = fn; 2021 return identities[id] = fn;
1999 } 2022 }
2139 { 2162 {
2140 llvm::Type *complex_ret = instance->complex_ret; 2163 llvm::Type *complex_ret = instance->complex_ret;
2141 llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0)); 2164 llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0));
2142 llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1)); 2165 llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1));
2143 llvm::Value *ret = llvm::UndefValue::get (complex_ret); 2166 llvm::Value *ret = llvm::UndefValue::get (complex_ret);
2144 ret = bld.CreateInsertValue (ret, real, 0); 2167
2145 return bld.CreateInsertValue (ret, imag, 1); 2168 unsigned int re_idx[] = {0, 0};
2169 unsigned int im_idx[] = {0, 1};
2170 ret = bld.CreateInsertValue (ret, real, re_idx);
2171 return bld.CreateInsertValue (ret, imag, im_idx);
2146 } 2172 }
2147 2173
2148 llvm::Value * 2174 llvm::Value *
2149 jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result) 2175 jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result)
2150 { 2176 {
2177 unsigned int re_idx[] = {0, 0};
2178 unsigned int im_idx[] = {0, 1};
2179
2151 llvm::Type *complex_t = get_complex ()->to_llvm (); 2180 llvm::Type *complex_t = get_complex ()->to_llvm ();
2152 llvm::Value *real = bld.CreateExtractValue (result, 0); 2181 llvm::Value *real = bld.CreateExtractValue (result, re_idx);
2153 llvm::Value *imag = bld.CreateExtractValue (result, 1); 2182 llvm::Value *imag = bld.CreateExtractValue (result, im_idx);
2154 llvm::Value *ret = llvm::UndefValue::get (complex_t); 2183 llvm::Value *ret = llvm::UndefValue::get (complex_t);
2184
2155 ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0)); 2185 ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0));
2156 return bld.CreateInsertElement (ret, imag, bld.getInt32 (1)); 2186 return bld.CreateInsertElement (ret, imag, bld.getInt32 (1));
2157 } 2187 }
2158 2188
2159 llvm::Value * 2189 llvm::Value *