# HG changeset patch # User Max Brister # Date 1344545159 18000 # Node ID bd6bb87e2bea623229b3011776b9b395e65505c1 # Parent edae65062740af4158e64a0e2328f0f56c05ad11 Support sin, cos, and exp with matrix arguments in JIT * src/interp-core/jit-typeinfo.cc (jit_operation::generate): Remove unused parameter name. (jit_typeinfo::jit_typeinfo): Create any_call function. (jit_typeinfo::register_generic): Implement. * src/interp-core/jit-typeinfo.h (jit_typeinfo): New field, any_call. * src/interp-core/pt-jit.cc: New test. diff -r edae65062740 -r bd6bb87e2bea src/interp-core/jit-typeinfo.cc --- a/src/interp-core/jit-typeinfo.cc Thu Aug 09 08:29:50 2012 -0700 +++ b/src/interp-core/jit-typeinfo.cc Thu Aug 09 15:45:59 2012 -0500 @@ -837,7 +837,7 @@ } jit_function * -jit_operation::generate (const signature_vec& types) const +jit_operation::generate (const signature_vec&) const { return 0; } @@ -1041,6 +1041,7 @@ complex = new_type ("complex", any, complex_t); scalar = new_type ("scalar", complex, scalar_t); scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); + any_ptr = new_type ("any_ptr", 0, any_t->getPointerTo ()); range = new_type ("range", any, range_t); string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); @@ -1080,6 +1081,14 @@ engine->addGlobalMapping (lerror_state, reinterpret_cast (&error_state)); + // generic call function + { + jit_type *int_t = intN (sizeof (octave_builtin::fcn) * 8); + any_call = create_function (jit_convention::external, "octave_jit_call", + any, int_t, int_t, any_ptr, int_t); + any_call.add_mapping (engine, &octave_jit_call); + } + // any with anything is an any op jit_function fn; jit_type *binary_op_type = intN (sizeof (octave_value::binary_op) * 8); @@ -1974,10 +1983,48 @@ } void -jit_typeinfo::register_generic (const std::string&, jit_type *, - const std::vector&) +jit_typeinfo::register_generic (const std::string& name, jit_type *result, + const std::vector& args) { - // FIXME: Implement + octave_builtin *builtin = find_builtin (name); + if (! builtin) + return; + + std::vector fn_args (args.size () + 1); + fn_args[0] = builtins[name]; + std::copy (args.begin (), args.end (), fn_args.begin () + 1); + jit_function fn = create_function (jit_convention::internal, name, result, + fn_args); + llvm::BasicBlock *block = fn.new_block (); + builder.SetInsertPoint (block); + llvm::Type *any_t = any->to_llvm (); + llvm::ArrayType *array_t = llvm::ArrayType::get (any_t, args.size ()); + llvm::Value *array = llvm::UndefValue::get (array_t); + for (size_t i = 0; i < args.size (); ++i) + { + llvm::Value *arg = fn.argument (builder, i + 1); + jit_function agrab = get_grab (args[i]); + llvm::Value *garg = agrab.call (builder, arg); + jit_function acast = cast (any, args[i]); + array = builder.CreateInsertValue (array, acast.call (builder, garg), i); + } + + llvm::Value *array_mem = builder.CreateAlloca (array_t); + builder.CreateStore (array, array_mem); + array = builder.CreateBitCast (array_mem, any_t->getPointerTo ()); + + jit_type *jintTy = intN (sizeof (octave_builtin::fcn) * 8); + llvm::Type *intTy = jintTy->to_llvm (); + size_t fcn_int = reinterpret_cast (builtin->function ()); + llvm::Value *fcn = llvm::ConstantInt::get (intTy, fcn_int); + llvm::Value *nargin = llvm::ConstantInt::get (intTy, args.size ()); + size_t result_int = reinterpret_cast (result); + llvm::Value *res_llvm = llvm::ConstantInt::get (intTy, result_int); + llvm::Value *ret = any_call.call (builder, fcn, nargin, array, res_llvm); + + jit_function cast_result = cast (result, any); + fn.do_return (builder, cast_result.call (builder, ret)); + paren_subsref_fn.add_overload (fn); } jit_function diff -r edae65062740 -r bd6bb87e2bea src/interp-core/jit-typeinfo.h --- a/src/interp-core/jit-typeinfo.h Thu Aug 09 08:29:50 2012 -0700 +++ b/src/interp-core/jit-typeinfo.h Thu Aug 09 15:45:59 2012 -0500 @@ -724,6 +724,7 @@ jit_type *matrix; jit_type *scalar; jit_type *scalar_ptr; // a fake type for interfacing with C++ + jit_type *any_ptr; // a fake type for interfacing with C++ jit_type *range; jit_type *string; jit_type *boolean; @@ -749,6 +750,8 @@ jit_operation end1_fn; jit_operation end_fn; + jit_function any_call; + // type id -> cast function TO that type std::vector casts; diff -r edae65062740 -r bd6bb87e2bea src/interp-core/pt-jit.cc --- a/src/interp-core/pt-jit.cc Thu Aug 09 08:29:50 2012 -0700 +++ b/src/interp-core/pt-jit.cc Thu Aug 09 15:45:59 2012 -0500 @@ -1940,4 +1940,12 @@ %! m2(2, :) = 1:1001; %! assert (m, m2); +%!test +%! m = [1 2 3]; +%! for i=1:1001 +%! m = sin (m); +%! break; +%! endfor +%! assert (m == sin ([1 2 3])); + */