Mercurial > octave-nkf
comparison libinterp/interp-core/jit-typeinfo.cc @ 15370:8355fddce815
Use sret and stop using llvm stacksave/stackrestore (bug #37308)
* jit-typeinfo.cc (octave_jit_grab_matrix, octave_jit_cast_matrix_any,
octave_jit_paren_subsasgn_impl, octave_jit_paren_scalar_subsasgn,
octave_jit_paren_subsasgn_matrix_range): Return matrix directly.
(octave_jit_cast_range_any): Return range directly.
(jit_function::jit_function): Mark the llvm function's return as sret when required.
(jit_function::call): Mark the llvm call as sret when required, and place allocas
in the calling function's entry block.
(jit_function::do_return): Handle new parameter, verify.
(jit_typeinfo::jit_typeinfo): Match the C++ std::complex layout more closely, pass
jit_convention::external explicitly, and disable the mirrored complex division
overloads (op_ldiv, op_el_ldiv).
(jit_typeinfo::create_identity): Give each identity function a type-specific name (id_<type>).
(jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): Handle the changed
complex_ret layout (a struct wrapping an inner two-element scalar struct).
* jit-typeinfo.h (jit_array::jit_array): New overload.
(jit_type::mark_sret, jit_type::mark_pointer_arg): Remove the default convention argument; callers now pass it explicitly.
(jit_function::do_return): Add verify parameter.
* pt-jit.cc (jit_convert_llvm::convert_function): Store the jit_function.
(jit_convert::visit): Call do_return if converting a function.
* pt-jit.h (jit_convert_llvm::creating): New member variable.
author | Max Brister <max@2bass.com> |
---|---|
date | Wed, 12 Sep 2012 19:18:51 -0600 |
parents | 3f43e9d6d86e |
children | 8ccb187b24e9 |
comparison
equal
deleted
inserted
replaced
15369:715220d2b511 | 15370:8355fddce815 |
---|---|
111 { | 111 { |
112 obv->grab (); | 112 obv->grab (); |
113 return obv; | 113 return obv; |
114 } | 114 } |
115 | 115 |
116 extern "C" void | 116 extern "C" jit_matrix |
117 octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m) | 117 octave_jit_grab_matrix (jit_matrix *m) |
118 { | 118 { |
119 *result = *m->array; | 119 return *m->array; |
120 } | 120 } |
121 | 121 |
122 extern "C" octave_base_value * | 122 extern "C" octave_base_value * |
123 octave_jit_cast_any_matrix (jit_matrix *m) | 123 octave_jit_cast_any_matrix (jit_matrix *m) |
124 { | 124 { |
128 delete m->array; | 128 delete m->array; |
129 | 129 |
130 return rep; | 130 return rep; |
131 } | 131 } |
132 | 132 |
133 extern "C" void | 133 extern "C" jit_matrix |
134 octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) | 134 octave_jit_cast_matrix_any (octave_base_value *obv) |
135 { | 135 { |
136 NDArray m = obv->array_value (); | 136 NDArray m = obv->array_value (); |
137 *ret = m; | |
138 obv->release (); | 137 obv->release (); |
138 return m; | |
139 } | 139 } |
140 | 140 |
141 extern "C" octave_base_value * | 141 extern "C" octave_base_value * |
142 octave_jit_cast_any_range (jit_range *rng) | 142 octave_jit_cast_any_range (jit_range *rng) |
143 { | 143 { |
146 octave_base_value *rep = ret.internal_rep (); | 146 octave_base_value *rep = ret.internal_rep (); |
147 rep->grab (); | 147 rep->grab (); |
148 | 148 |
149 return rep; | 149 return rep; |
150 } | 150 } |
151 extern "C" void | 151 extern "C" jit_range |
152 octave_jit_cast_range_any (jit_range *ret, octave_base_value *obv) | 152 octave_jit_cast_range_any (octave_base_value *obv) |
153 { | 153 { |
154 | 154 |
155 jit_range r (obv->range_value ()); | 155 jit_range r (obv->range_value ()); |
156 *ret = r; | |
157 obv->release (); | 156 obv->release (); |
157 return r; | |
158 } | 158 } |
159 | 159 |
160 extern "C" double | 160 extern "C" double |
161 octave_jit_cast_scalar_any (octave_base_value *obv) | 161 octave_jit_cast_scalar_any (octave_base_value *obv) |
162 { | 162 { |
226 { | 226 { |
227 gripe_library_execution_error (); | 227 gripe_library_execution_error (); |
228 } | 228 } |
229 } | 229 } |
230 | 230 |
231 extern "C" void | 231 extern "C" jit_matrix |
232 octave_jit_paren_subsasgn_impl (jit_matrix *ret, jit_matrix *mat, | 232 octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, |
233 octave_idx_type index, double value) | 233 double value) |
234 { | 234 { |
235 NDArray *array = mat->array; | 235 NDArray *array = mat->array; |
236 if (array->nelem () < index) | 236 if (array->nelem () < index) |
237 array->resize1 (index); | 237 array->resize1 (index); |
238 | 238 |
239 double *data = array->fortran_vec (); | 239 double *data = array->fortran_vec (); |
240 data[index - 1] = value; | 240 data[index - 1] = value; |
241 | 241 |
242 mat->update (); | 242 mat->update (); |
243 *ret = *mat; | 243 return *mat; |
244 } | 244 } |
245 | 245 |
246 static void | 246 static void |
247 make_indices (double *indices, octave_idx_type idx_count, | 247 make_indices (double *indices, octave_idx_type idx_count, |
248 Array<idx_vector>& result) | 248 Array<idx_vector>& result) |
270 gripe_library_execution_error (); | 270 gripe_library_execution_error (); |
271 return 0; | 271 return 0; |
272 } | 272 } |
273 } | 273 } |
274 | 274 |
275 extern "C" void | 275 extern "C" jit_matrix |
276 octave_jit_paren_scalar_subsasgn (jit_matrix *ret, jit_matrix *mat, | 276 octave_jit_paren_scalar_subsasgn (jit_matrix *mat, double *indices, |
277 double *indices, octave_idx_type idx_count, | 277 octave_idx_type idx_count, double value) |
278 double value) | |
279 { | 278 { |
280 // FIXME: Replace this with a more optimal version | 279 // FIXME: Replace this with a more optimal version |
280 jit_matrix ret; | |
281 try | 281 try |
282 { | 282 { |
283 Array<idx_vector> idx; | 283 Array<idx_vector> idx; |
284 make_indices (indices, idx_count, idx); | 284 make_indices (indices, idx_count, idx); |
285 | 285 |
286 Matrix temp (1, 1); | 286 Matrix temp (1, 1); |
287 temp.xelem(0) = value; | 287 temp.xelem(0) = value; |
288 mat->array->assign (idx, temp); | 288 mat->array->assign (idx, temp); |
289 ret->update (mat->array); | 289 ret.update (mat->array); |
290 } | 290 } |
291 catch (const octave_execution_exception&) | 291 catch (const octave_execution_exception&) |
292 { | 292 { |
293 gripe_library_execution_error (); | 293 gripe_library_execution_error (); |
294 } | 294 } |
295 } | 295 |
296 | 296 return ret; |
297 extern "C" void | 297 } |
298 octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, | 298 |
299 jit_range *index, double value) | 299 extern "C" jit_matrix |
300 octave_jit_paren_subsasgn_matrix_range (jit_matrix *mat, jit_range *index, | |
301 double value) | |
300 { | 302 { |
301 NDArray *array = mat->array; | 303 NDArray *array = mat->array; |
302 bool done = false; | 304 bool done = false; |
303 | 305 |
304 // optimize for the simple case (no resizing and no errors) | 306 // optimize for the simple case (no resizing and no errors) |
338 NDArray avalue (dim_vector (1, 1)); | 340 NDArray avalue (dim_vector (1, 1)); |
339 avalue.xelem (0) = value; | 341 avalue.xelem (0) = value; |
340 array->assign (idx, avalue); | 342 array->assign (idx, avalue); |
341 } | 343 } |
342 | 344 |
343 result->update (array); | 345 jit_matrix ret; |
346 ret.update (array); | |
347 return ret; | |
344 } | 348 } |
345 | 349 |
346 extern "C" double | 350 extern "C" double |
347 octave_jit_end_matrix (jit_matrix *mat, octave_idx_type idx, | 351 octave_jit_end_matrix (jit_matrix *mat, octave_idx_type idx, |
348 octave_idx_type count) | 352 octave_idx_type count) |
560 // we mark all functinos as external linkage because this prevents llvm | 564 // we mark all functinos as external linkage because this prevents llvm |
561 // from getting rid of always inline functions | 565 // from getting rid of always inline functions |
562 llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); | 566 llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); |
563 llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, | 567 llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, |
564 aname, module); | 568 aname, module); |
569 | |
570 if (sret ()) | |
571 llvm_function->addAttribute (1, llvm::Attribute::StructRet); | |
572 | |
565 if (call_conv == jit_convention::internal) | 573 if (call_conv == jit_convention::internal) |
566 llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); | 574 llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); |
567 } | 575 } |
568 | 576 |
569 jit_function::jit_function (const jit_function& fn, jit_type *aresult, | 577 jit_function::jit_function (const jit_function& fn, jit_type *aresult, |
618 llvm::Function *stacksave | 626 llvm::Function *stacksave |
619 = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); | 627 = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); |
620 llvm::SmallVector<llvm::Value *, 10> llvm_args; | 628 llvm::SmallVector<llvm::Value *, 10> llvm_args; |
621 llvm_args.reserve (in_args.size () + sret ()); | 629 llvm_args.reserve (in_args.size () + sret ()); |
622 | 630 |
623 llvm::Value *sret_mem = 0; | 631 llvm::BasicBlock *insert_block = builder.GetInsertBlock (); |
624 llvm::Value *saved_stack = 0; | 632 llvm::Function *parent = insert_block->getParent (); |
633 assert (parent); | |
634 | |
635 // we insert allocas inside the prelude block to prevent stack overflows | |
636 llvm::BasicBlock& prelude = parent->getEntryBlock (); | |
637 llvm::IRBuilder<> pre_builder (&prelude, prelude.begin ()); | |
638 | |
639 llvm::AllocaInst *sret_mem = 0; | |
625 if (sret ()) | 640 if (sret ()) |
626 { | 641 { |
627 saved_stack = builder.CreateCall (stacksave); | 642 sret_mem = pre_builder.CreateAlloca (mresult->packed_type (call_conv)); |
628 sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv)); | |
629 llvm_args.push_back (sret_mem); | 643 llvm_args.push_back (sret_mem); |
630 } | 644 } |
631 | 645 |
632 for (size_t i = 0; i < in_args.size (); ++i) | 646 for (size_t i = 0; i < in_args.size (); ++i) |
633 { | 647 { |
636 if (convert) | 650 if (convert) |
637 arg = convert (builder, arg); | 651 arg = convert (builder, arg); |
638 | 652 |
639 if (args[i]->pointer_arg (call_conv)) | 653 if (args[i]->pointer_arg (call_conv)) |
640 { | 654 { |
641 if (! saved_stack) | 655 llvm::Type *ty = args[i]->packed_type (call_conv); |
642 saved_stack = builder.CreateCall (stacksave); | 656 llvm::Value *alloca = pre_builder.CreateAlloca (ty); |
643 | 657 builder.CreateStore (arg, alloca); |
644 arg = builder.CreateAlloca (args[i]->to_llvm ()); | 658 arg = alloca; |
645 builder.CreateStore (in_args[i], arg); | |
646 } | 659 } |
647 | 660 |
648 llvm_args.push_back (arg); | 661 llvm_args.push_back (arg); |
649 } | 662 } |
650 | 663 |
651 llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args); | 664 llvm::CallInst *callinst = builder.CreateCall (llvm_function, llvm_args); |
652 if (sret_mem) | 665 llvm::Value *ret = callinst; |
653 ret = builder.CreateLoad (sret_mem); | 666 |
667 if (sret ()) | |
668 { | |
669 callinst->addAttribute (1, llvm::Attribute::StructRet); | |
670 ret = builder.CreateLoad (sret_mem); | |
671 } | |
654 | 672 |
655 if (mresult) | 673 if (mresult) |
656 { | 674 { |
657 jit_type::convert_fn unpack = mresult->unpack (call_conv); | 675 jit_type::convert_fn unpack = mresult->unpack (call_conv); |
658 if (unpack) | 676 if (unpack) |
659 ret = unpack (builder, ret); | 677 ret = unpack (builder, ret); |
660 } | |
661 | |
662 if (saved_stack) | |
663 { | |
664 llvm::Function *stackrestore | |
665 = llvm::Intrinsic::getDeclaration (module, | |
666 llvm::Intrinsic::stackrestore); | |
667 builder.CreateCall (stackrestore, saved_stack); | |
668 } | 678 } |
669 | 679 |
670 return ret; | 680 return ret; |
671 } | 681 } |
672 | 682 |
689 | 699 |
690 return iter; | 700 return iter; |
691 } | 701 } |
692 | 702 |
693 void | 703 void |
694 jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval) | 704 jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval, |
705 bool verify) | |
695 { | 706 { |
696 assert (! rval == ! mresult); | 707 assert (! rval == ! mresult); |
697 | 708 |
698 if (rval) | 709 if (rval) |
699 { | 710 { |
700 jit_type::convert_fn convert = mresult->pack (call_conv); | 711 jit_type::convert_fn convert = mresult->pack (call_conv); |
701 if (convert) | 712 if (convert) |
702 rval = convert (builder, rval); | 713 rval = convert (builder, rval); |
703 | 714 |
704 if (sret ()) | 715 if (sret ()) |
705 builder.CreateStore (rval, llvm_function->arg_begin ()); | 716 { |
717 builder.CreateStore (rval, llvm_function->arg_begin ()); | |
718 builder.CreateRetVoid (); | |
719 } | |
706 else | 720 else |
707 builder.CreateRet (rval); | 721 builder.CreateRet (rval); |
708 } | 722 } |
709 else | 723 else |
710 builder.CreateRetVoid (); | 724 builder.CreateRetVoid (); |
711 | 725 |
712 llvm::verifyFunction (*llvm_function); | 726 if (verify) |
727 llvm::verifyFunction (*llvm_function); | |
713 } | 728 } |
714 | 729 |
715 void | 730 void |
716 jit_function::do_add_mapping (llvm::ExecutionEngine *engine, void *fn) | 731 jit_function::do_add_mapping (llvm::ExecutionEngine *engine, void *fn) |
717 { | 732 { |
1030 | 1045 |
1031 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); | 1046 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); |
1032 | 1047 |
1033 // complex_ret is what is passed to C functions in order to get calling | 1048 // complex_ret is what is passed to C functions in order to get calling |
1034 // convention right | 1049 // convention right |
1050 llvm::Type *cmplx_inner_cont[] = {scalar_t, scalar_t}; | |
1051 llvm::StructType *cmplx_inner = llvm::StructType::create (cmplx_inner_cont); | |
1052 | |
1035 complex_ret = llvm::StructType::create (context, "complex_ret"); | 1053 complex_ret = llvm::StructType::create (context, "complex_ret"); |
1036 llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; | 1054 { |
1037 complex_ret->setBody (complex_ret_contents); | 1055 llvm::Type *contents[] = {cmplx_inner}; |
1056 complex_ret->setBody (contents); | |
1057 } | |
1038 | 1058 |
1039 // create types | 1059 // create types |
1040 any = new_type ("any", 0, any_t); | 1060 any = new_type ("any", 0, any_t); |
1041 matrix = new_type ("matrix", any, matrix_t); | 1061 matrix = new_type ("matrix", any, matrix_t); |
1042 complex = new_type ("complex", any, complex_t); | 1062 complex = new_type ("complex", any, complex_t); |
1057 identities.resize (next_id + 1); | 1077 identities.resize (next_id + 1); |
1058 | 1078 |
1059 // specify calling conventions | 1079 // specify calling conventions |
1060 // FIXME: We should detect architecture and do something sane based on that | 1080 // FIXME: We should detect architecture and do something sane based on that |
1061 // here we assume x86 or x86_64 | 1081 // here we assume x86 or x86_64 |
1062 matrix->mark_sret (); | 1082 matrix->mark_sret (jit_convention::external); |
1063 matrix->mark_pointer_arg (); | 1083 matrix->mark_pointer_arg (jit_convention::external); |
1064 | 1084 |
1065 range->mark_sret (); | 1085 range->mark_sret (jit_convention::external); |
1066 range->mark_pointer_arg (); | 1086 range->mark_pointer_arg (jit_convention::external); |
1067 | 1087 |
1068 complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); | 1088 complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); |
1069 complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); | 1089 complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); |
1070 complex->set_packed_type (jit_convention::external, complex_ret); | 1090 complex->set_packed_type (jit_convention::external, complex_ret); |
1071 | 1091 |
1072 if (sizeof (void *) == 4) | 1092 if (sizeof (void *) == 4) |
1073 complex->mark_sret (); | 1093 complex->mark_sret (jit_convention::external); |
1074 | 1094 |
1075 paren_subsref_fn.initialize (module, engine); | 1095 paren_subsref_fn.initialize (module, engine); |
1076 paren_subsasgn_fn.initialize (module, engine); | 1096 paren_subsasgn_fn.initialize (module, engine); |
1077 | 1097 |
1078 // bind global variables | 1098 // bind global variables |
1331 complex_div.add_mapping (engine, &octave_jit_complex_div); | 1351 complex_div.add_mapping (engine, &octave_jit_complex_div); |
1332 complex_div.mark_can_error (); | 1352 complex_div.mark_can_error (); |
1333 binary_ops[octave_value::op_div].add_overload (fn); | 1353 binary_ops[octave_value::op_div].add_overload (fn); |
1334 binary_ops[octave_value::op_ldiv].add_overload (fn); | 1354 binary_ops[octave_value::op_ldiv].add_overload (fn); |
1335 | 1355 |
1336 fn = mirror_binary (complex_div); | 1356 // fn = mirror_binary (complex_div); |
1337 binary_ops[octave_value::op_ldiv].add_overload (fn); | 1357 // binary_ops[octave_value::op_ldiv].add_overload (fn); |
1338 binary_ops[octave_value::op_el_ldiv].add_overload (fn); | 1358 // binary_ops[octave_value::op_el_ldiv].add_overload (fn); |
1339 | 1359 |
1340 fn = create_function (jit_convention::external, | 1360 fn = create_function (jit_convention::external, |
1341 "octave_jit_pow_complex_complex", complex, complex, | 1361 "octave_jit_pow_complex_complex", complex, complex, |
1342 complex); | 1362 complex); |
1343 fn.add_mapping (engine, &octave_jit_pow_complex_complex); | 1363 fn.add_mapping (engine, &octave_jit_pow_complex_complex); |
1988 if (id >= identities.size ()) | 2008 if (id >= identities.size ()) |
1989 identities.resize (id + 1); | 2009 identities.resize (id + 1); |
1990 | 2010 |
1991 if (! identities[id].valid ()) | 2011 if (! identities[id].valid ()) |
1992 { | 2012 { |
1993 jit_function fn = create_function (jit_convention::internal, "id", type, | 2013 std::stringstream name; |
1994 type); | 2014 name << "id_" << type->name (); |
2015 jit_function fn = create_function (jit_convention::internal, name.str (), | |
2016 type, type); | |
2017 | |
1995 llvm::BasicBlock *body = fn.new_block (); | 2018 llvm::BasicBlock *body = fn.new_block (); |
1996 builder.SetInsertPoint (body); | 2019 builder.SetInsertPoint (body); |
1997 fn.do_return (builder, fn.argument (builder, 0)); | 2020 fn.do_return (builder, fn.argument (builder, 0)); |
1998 return identities[id] = fn; | 2021 return identities[id] = fn; |
1999 } | 2022 } |
2139 { | 2162 { |
2140 llvm::Type *complex_ret = instance->complex_ret; | 2163 llvm::Type *complex_ret = instance->complex_ret; |
2141 llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0)); | 2164 llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0)); |
2142 llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1)); | 2165 llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1)); |
2143 llvm::Value *ret = llvm::UndefValue::get (complex_ret); | 2166 llvm::Value *ret = llvm::UndefValue::get (complex_ret); |
2144 ret = bld.CreateInsertValue (ret, real, 0); | 2167 |
2145 return bld.CreateInsertValue (ret, imag, 1); | 2168 unsigned int re_idx[] = {0, 0}; |
2169 unsigned int im_idx[] = {0, 1}; | |
2170 ret = bld.CreateInsertValue (ret, real, re_idx); | |
2171 return bld.CreateInsertValue (ret, imag, im_idx); | |
2146 } | 2172 } |
2147 | 2173 |
2148 llvm::Value * | 2174 llvm::Value * |
2149 jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result) | 2175 jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result) |
2150 { | 2176 { |
2177 unsigned int re_idx[] = {0, 0}; | |
2178 unsigned int im_idx[] = {0, 1}; | |
2179 | |
2151 llvm::Type *complex_t = get_complex ()->to_llvm (); | 2180 llvm::Type *complex_t = get_complex ()->to_llvm (); |
2152 llvm::Value *real = bld.CreateExtractValue (result, 0); | 2181 llvm::Value *real = bld.CreateExtractValue (result, re_idx); |
2153 llvm::Value *imag = bld.CreateExtractValue (result, 1); | 2182 llvm::Value *imag = bld.CreateExtractValue (result, im_idx); |
2154 llvm::Value *ret = llvm::UndefValue::get (complex_t); | 2183 llvm::Value *ret = llvm::UndefValue::get (complex_t); |
2184 | |
2155 ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0)); | 2185 ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0)); |
2156 return bld.CreateInsertElement (ret, imag, bld.getInt32 (1)); | 2186 return bld.CreateInsertElement (ret, imag, bld.getInt32 (1)); |
2157 } | 2187 } |
2158 | 2188 |
2159 llvm::Value * | 2189 llvm::Value * |