Mercurial > octave-nkf
diff src/jit-typeinfo.cc @ 15078:fe4752f772e2
Generate ND indexing functions on demand in JIT.
* src/jit-typeinfo.cc (jit_operation::~jit_operation,
jit_operation::do_generate, jit_operation::generate,
jit_operation::signature_cmp::operator()): New function.
(jit_operation::overload): Call do_generate when lookup fails.
(jit_index_operation, jit_paren_subsref, jit_paren_subsasgn): New class.
(jit_typeinfo::jit_typeinfo): Update to use jit_paren_subsref and
jit_paren_subsasgn.
(jit_typeinfo::gen_subsref, jit_typeinfo::gen_subsasgn): Removed functions.
* src/jit-typeinfo.h (jit_operation::~jit_operation, jit_operation::generate,
jit_operation::do_generate): New declaration.
(jit_operation::add_overload, jit_operation::overload, jit_operation::result,
jit_operation::to_idx): Use signature_vec typedef.
(jit_operation::singature_cmp): New class.
(jit_index_operation, jit_paren_subsref, jit_paren_subsasgn): New class.
(jit_typeinfo::get_scalar_ptr): Nwe function.
(jit_typeinfo::gen_subsref, jit_typeinfo::gen_subsasgn): Removed declaration.
* src/pt-jit.cc: New test.
author | Max Brister <max@2bass.com> |
---|---|
date | Wed, 01 Aug 2012 17:00:12 -0500 |
parents | f57d7578c1a6 |
children | 9df70a18aa27 |
line wrap: on
line diff
--- a/src/jit-typeinfo.cc Wed Aug 01 12:10:26 2012 -0400 +++ b/src/jit-typeinfo.cc Wed Aug 01 17:00:12 2012 -0500 @@ -708,6 +708,16 @@ } // -------------------- jit_operation -------------------- +jit_operation::~jit_operation (void) +{ + for (generated_map::iterator iter = generated.begin (); + iter != generated.end (); ++iter) + { + delete iter->first; + delete iter->second; + } +} + void jit_operation::add_overload (const jit_function& func, const std::vector<jit_type*>& args) @@ -742,23 +752,26 @@ const jit_function& jit_operation::overload (const std::vector<jit_type*>& types) const { - // FIXME: We should search for the next best overload on failure static jit_function null_overload; - if (types.size () >= overloads.size ()) - return null_overload; - for (size_t i =0; i < types.size (); ++i) if (! types[i]) return null_overload; + if (types.size () >= overloads.size ()) + return do_generate (types); + const Array<jit_function>& over = overloads[types.size ()]; dim_vector dv (over.dims ()); Array<octave_idx_type> idx = to_idx (types); for (octave_idx_type i = 0; i < dv.length (); ++i) if (idx(i) >= dv(i)) - return null_overload; + return do_generate (types); - return over(idx); + const jit_function& ret = over(idx); + if (! ret.valid ()) + return do_generate (types); + + return ret; } Array<octave_idx_type> @@ -782,6 +795,175 @@ return idx; } +const jit_function& +jit_operation::do_generate (const signature_vec& types) const +{ + static jit_function null_overload; + generated_map::const_iterator find = generated.find (&types); + if (find != generated.end ()) + { + if (find->second) + return *find->second; + else + return null_overload; + } + + jit_function *ret = generate (types); + generated[new signature_vec (types)] = ret; + return ret ? *ret : null_overload; +} + +jit_function * +jit_operation::generate (const signature_vec& types) const +{ + return 0; +} + +bool +jit_operation::signature_cmp +::operator() (const signature_vec *lhs, const signature_vec *rhs) +{ + const signature_vec& l = *lhs; + const signature_vec& r = *rhs; + + if (l.size () < r.size ()) + return true; + else if (l.size () > r.size ()) + return false; + + for (size_t i = 0; i < l.size (); ++i) + { + if (l[i]->type_id () < r[i]->type_id ()) + return true; + else if (l[i]->type_id () > r[i]->type_id ()) + return false; + } + + return false; +} + +// -------------------- jit_index_operation -------------------- +jit_function * +jit_index_operation::generate (const signature_vec& types) const +{ + if (types.size () > 2 && types[0] == jit_typeinfo::get_matrix ()) + { + // indexing a matrix with scalars + jit_type *scalar = jit_typeinfo::get_scalar (); + for (size_t i = 1; i < types.size (); ++i) + if (types[i] != scalar) + return 0; + + return generate_matrix (types); + } + + return 0; +} + +llvm::Value * +jit_index_operation::create_arg_array (llvm::IRBuilderD& builder, + const jit_function &fn, size_t start_idx, + size_t end_idx) const +{ + size_t n = end_idx - start_idx; + llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm (); + llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); + llvm::Value *array = llvm::UndefValue::get (array_t); + for (size_t i = start_idx; i < end_idx; ++i) + { + llvm::Value *idx = fn.argument (builder, i); + array = builder.CreateInsertValue (array, idx, i - start_idx); + } + + llvm::Value *array_mem = builder.CreateAlloca (array_t); + builder.CreateStore (array, array_mem); + return builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); +} + +// -------------------- jit_paren_subsref -------------------- +jit_function * +jit_paren_subsref::generate_matrix (const signature_vec& types) const +{ + std::stringstream ss; + ss << "jit_paren_subsref_matrix_scalar" << (types.size () - 1); + + jit_type *scalar = jit_typeinfo::get_scalar (); + jit_function *fn = new jit_function (module, jit_convention::internal, + ss.str (), scalar, types); + fn->mark_can_error (); + llvm::BasicBlock *body = fn->new_block (); + llvm::IRBuilder<> builder (body); + + llvm::Value *array = create_arg_array (builder, *fn, 1, types.size ()); + jit_type *index = jit_typeinfo::get_index (); + llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), + types.size () - 1); + llvm::Value *mat = fn->argument (builder, 0); + llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem); + fn->do_return (builder, ret); + return fn; +} + +void +jit_paren_subsref::do_initialize (void) +{ + std::vector<jit_type *> types (3); + types[0] = jit_typeinfo::get_matrix (); + types[1] = jit_typeinfo::get_scalar_ptr (); + types[2] = jit_typeinfo::get_index (); + + jit_type *scalar = jit_typeinfo::get_scalar (); + paren_scalar = jit_function (module, jit_convention::external, + "octave_jit_paren_scalar", scalar, types); + paren_scalar.add_mapping (engine, &octave_jit_paren_scalar); + paren_scalar.mark_can_error (); +} + +// -------------------- jit_paren_subsasgn -------------------- +jit_function * +jit_paren_subsasgn::generate_matrix (const signature_vec& types) const +{ + std::stringstream ss; + ss << "jit_paren_subsasgn_matrix_scalar" << (types.size () - 2); + + jit_type *matrix = jit_typeinfo::get_matrix (); + jit_function *fn = new jit_function (module, jit_convention::internal, + ss.str (), matrix, types); + fn->mark_can_error (); + llvm::BasicBlock *body = fn->new_block (); + llvm::IRBuilder<> builder (body); + + llvm::Value *array = create_arg_array (builder, *fn, 1, types.size () - 1); + jit_type *index = jit_typeinfo::get_index (); + llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), + types.size () - 2); + + llvm::Value *mat = fn->argument (builder, 0); + llvm::Value *value = fn->argument (builder, types.size () - 1); + llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value); + fn->do_return (builder, ret); + return fn; +} + +void +jit_paren_subsasgn::do_initialize (void) +{ + if (paren_scalar.valid ()) + return; + + jit_type *matrix = jit_typeinfo::get_matrix (); + std::vector<jit_type *> types (4); + types[0] = matrix; + types[1] = jit_typeinfo::get_scalar_ptr (); + types[2] = jit_typeinfo::get_index (); + types[3] = jit_typeinfo::get_scalar (); + + paren_scalar = jit_function (module, jit_convention::external, + "octave_jit_paren_scalar", matrix, types); + paren_scalar.add_mapping (engine, &octave_jit_paren_scalar_subsasgn); + paren_scalar.mark_can_error (); +} + // -------------------- jit_typeinfo -------------------- void jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) @@ -835,14 +1017,12 @@ matrix = new_type ("matrix", any, matrix_t); complex = new_type ("complex", any, complex_t); scalar = new_type ("scalar", complex, scalar_t); + scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); range = new_type ("range", any, range_t); string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); - // a fake type for interfacing with C++ - jit_type *scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); - create_int (8); create_int (16); create_int (32); @@ -867,6 +1047,9 @@ if (sizeof (void *) == 4) complex->mark_sret (); + paren_subsref_fn.initialize (module, engine); + paren_subsasgn_fn.initialize (module, engine); + // bind global variables lerror_state = new llvm::GlobalVariable (*module, bool_t, false, llvm::GlobalValue::ExternalLinkage, @@ -1364,28 +1547,6 @@ } paren_subsref_fn.add_overload (fn); - // generate () subsref for ND indexing of matricies with scalars - jit_function paren_scalar = create_function (jit_convention::external, - "octave_jit_paren_scalar", - scalar, matrix, scalar_ptr, - index); - paren_scalar.add_mapping (engine, &octave_jit_paren_scalar); - paren_scalar.mark_can_error (); - - jit_function paren_scalar_subsasgn - = create_function (jit_convention::external, - "octave_jit_paren_scalar_subsasgn", matrix, matrix, - scalar_ptr, index, scalar); - paren_scalar_subsasgn.add_mapping (engine, &octave_jit_paren_scalar_subsasgn); - paren_scalar_subsasgn.mark_can_error (); - - // FIXME: Generate this on the fly - for (size_t i = 2; i < 10; ++i) - { - gen_subsref (paren_scalar, i); - gen_subsasgn (paren_scalar_subsasgn, i); - } - // paren subsasgn paren_subsasgn_fn.stash_name ("()subsasgn"); @@ -1907,71 +2068,4 @@ return ret; } -void -jit_typeinfo::gen_subsref (const jit_function& paren_scalar, size_t n) -{ - std::stringstream name; - name << "jit_paren_subsref_matrix_scalar" << n; - std::vector<jit_type *> args (n + 1, scalar); - args[0] = matrix; - jit_function fn = create_function (jit_convention::internal, name.str (), - scalar, args); - fn.mark_can_error (); - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - - llvm::Type *scalar_t = scalar->to_llvm (); - llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); - llvm::Value *array = llvm::UndefValue::get (array_t); - for (size_t i = 0; i < n; ++i) - { - llvm::Value *idx = fn.argument (builder, i + 1); - array = builder.CreateInsertValue (array, idx, i); - } - - llvm::Value *array_mem = builder.CreateAlloca (array_t); - builder.CreateStore (array, array_mem); - array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); - - llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n); - llvm::Value *mat = fn.argument (builder, 0); - llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem); - fn.do_return (builder, ret); - paren_subsref_fn.add_overload (fn); -} - -void -jit_typeinfo::gen_subsasgn (const jit_function& paren_scalar, size_t n) -{ - std::stringstream name; - name << "jit_paren_subsasgn_matrix_scalar" << n; - std::vector<jit_type *> args (n + 2, scalar); - args[0] = matrix; - jit_function fn = create_function (jit_convention::internal, name.str (), - matrix, args); - fn.mark_can_error (); - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - - llvm::Type *scalar_t = scalar->to_llvm (); - llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); - llvm::Value *array = llvm::UndefValue::get (array_t); - for (size_t i = 0; i < n; ++i) - { - llvm::Value *idx = fn.argument (builder, i + 1); - array = builder.CreateInsertValue (array, idx, i); - } - - llvm::Value *array_mem = builder.CreateAlloca (array_t); - builder.CreateStore (array, array_mem); - array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); - - llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n); - llvm::Value *mat = fn.argument (builder, 0); - llvm::Value *value = fn.argument (builder, n + 1); - llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value); - fn.do_return (builder, ret); - paren_subsasgn_fn.add_overload (fn); -} - #endif