# HG changeset patch # User Max Brister # Date 1343671529 18000 # Node ID bc32288f4a42a21ee0a869e15f37eaa5447b1934 # Parent a6d4965ef04bb6ebce4c74e9df55f8cb26991d7e Support the end keyword for one dimentional indexing in JIT. * src/jit-ir.cc (jit_magic_end): New class. * src/jit-ir.h (jit_magic_end): New class. (jit_instruction::jit_instruction): New overload. * src/jit-typeinfo.cc (jit_function::call): Throw jit_fail_exception if invalid. (jit_typeinfo::jit_typeinfo): Initialize end_fn. * src/jit-typeinfo.h (jit_typeinfo::end): New function. * src/pt-jit.cc (jit_convert::visit_identifier): Handle magic_end. (jit_convert::resolve): Keep track of end context. (jit_convert::convert_llvm::visit): New overload. * src/pt-jit.h (jit_convert): Add end_context. diff -r a6d4965ef04b -r bc32288f4a42 src/jit-ir.cc --- a/src/jit-ir.cc Mon Jul 30 16:23:52 2012 +0100 +++ b/src/jit-ir.cc Mon Jul 30 13:05:29 2012 -0500 @@ -598,4 +598,36 @@ return false; } +// -------------------- jit_magic_end -------------------- +const jit_function& +jit_magic_end::overload () const +{ + jit_value *ctx = resolve_context (); + if (ctx) + return jit_typeinfo::end (ctx->type ()); + + static jit_function null_ret; + return null_ret; +} + +jit_value * +jit_magic_end::resolve_context (void) const +{ + // FIXME: We need to have a way of marking functions so we can skip them here + return argument_count () ? argument (0) : 0; +} + +bool +jit_magic_end::infer (void) +{ + jit_type *new_type = overload ().result (); + if (new_type != type ()) + { + stash_type (new_type); + return true; + } + + return false; +} + #endif diff -r a6d4965ef04b -r bc32288f4a42 src/jit-ir.h --- a/src/jit-ir.h Mon Jul 30 16:23:52 2012 +0100 +++ b/src/jit-ir.h Mon Jul 30 13:05:29 2012 -0500 @@ -46,7 +46,8 @@ JIT_METH(variable); \ JIT_METH(error_check); \ JIT_METH(assign) \ - JIT_METH(argument) + JIT_METH(argument) \ + JIT_METH(magic_end) #define JIT_VISIT_IR_CONST \ JIT_METH(const_bool); \ @@ -256,6 +257,14 @@ #undef STASH_ARG #undef JIT_INSTRUCTION_CTOR + jit_instruction (const std::vector& aarguments) + : already_infered (aarguments.size ()), marguments (aarguments.size ()), + mid (next_id ()), mparent (0) + { + for (size_t i = 0; i < aarguments.size (); ++i) + stash_argument (i, aarguments[i]); + } + static void reset_ids (void) { next_id (true); @@ -1137,6 +1146,34 @@ } }; +// for now only handles the 1D case +class +jit_magic_end : public jit_instruction +{ +public: + jit_magic_end (const std::vector& context) + : jit_instruction (context) + {} + + const jit_function& overload () const; + + jit_value *resolve_context (void) const; + + virtual bool infer (void); + + virtual std::ostream& short_print (std::ostream& os) const + { + return os << "magic_end"; + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + return short_print (print_indent (os, indent)); + } + + JIT_VALUE_ACCEPT; +}; + class jit_extract_argument : public jit_assign_base { diff -r a6d4965ef04b -r bc32288f4a42 src/jit-typeinfo.cc --- a/src/jit-typeinfo.cc Mon Jul 30 16:23:52 2012 +0100 +++ b/src/jit-typeinfo.cc Mon Jul 30 13:05:29 2012 -0500 @@ -522,8 +522,10 @@ jit_function::call (llvm::IRBuilderD& builder, const std::vector& in_args) const { + if (! valid ()) + throw jit_fail_exception ("Call not implemented"); + assert (in_args.size () == args.size ()); - std::vector llvm_args (args.size ()); for (size_t i = 0; i < in_args.size (); ++i) llvm_args[i] = in_args[i]->to_llvm (); @@ -535,7 +537,9 @@ jit_function::call (llvm::IRBuilderD& builder, const std::vector& in_args) const { - assert (valid ()); + if (! valid ()) + throw jit_fail_exception ("Call not implemented"); + assert (in_args.size () == args.size ()); llvm::Function *stacksave = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); @@ -1342,8 +1346,7 @@ builder.CreateBr (done); builder.SetInsertPoint (normal); - llvm::Value *len = builder.CreateExtractValue (mat, - llvm::ArrayRef (2)); + llvm::Value *len = builder.CreateExtractValue (mat, 2); cond0 = builder.CreateICmpSGT (int_idx, len); llvm::Value *rcount = builder.CreateExtractValue (mat, 0); @@ -1386,6 +1389,18 @@ fn.mark_can_error (); paren_subsasgn_fn.add_overload (fn); + end_fn.stash_name ("end"); + fn = create_function (jit_convention::internal, "octave_jit_end_matrix", + scalar, matrix); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *mat = fn.argument (builder, 0); + llvm::Value *ret = builder.CreateExtractValue (mat, 2); + fn.do_return (builder, builder.CreateSIToFP (ret, scalar_t)); + } + end_fn.add_overload (fn); + casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); casts[complex->type_id ()].stash_name ("(complex)"); diff -r a6d4965ef04b -r bc32288f4a42 src/jit-typeinfo.h --- a/src/jit-typeinfo.h Mon Jul 30 16:23:52 2012 +0100 +++ b/src/jit-typeinfo.h Mon Jul 30 13:05:29 2012 -0500 @@ -471,6 +471,16 @@ { return instance->do_insert_error_check (bld); } + + static const jit_operation& end (void) + { + return instance->end_fn; + } + + static const jit_function& end (jit_type *ty) + { + return instance->end_fn.overload (ty); + } private: jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); @@ -655,6 +665,7 @@ jit_operation make_range_fn; jit_operation paren_subsref_fn; jit_operation paren_subsasgn_fn; + jit_operation end_fn; // type id -> cast function TO that type std::vector casts; diff -r a6d4965ef04b -r bc32288f4a42 src/pt-jit.cc --- a/src/pt-jit.cc Mon Jul 30 16:23:52 2012 +0100 +++ b/src/pt-jit.cc Mon Jul 30 13:05:29 2012 -0500 @@ -412,7 +412,14 @@ void jit_convert::visit_identifier (tree_identifier& ti) { - result = get_variable (ti.name ()); + if (ti.has_magic_end ()) + { + if (!end_context.size ()) + throw jit_fail_exception ("Illegal end"); + result = block->append (create (end_context)); + } + else + result = get_variable (ti.name ()); } void @@ -826,6 +833,12 @@ tree_expression *tree_object = exp.expression (); jit_value *object = visit (tree_object); + + end_context.push_back (object); + + unwind_protect prot; + prot.add_method (&end_context, &std::vector::pop_back); + tree_expression *arg0 = arg_list->front (); jit_value *index = visit (arg0); @@ -1479,6 +1492,14 @@ jit_convert::convert_llvm::visit (jit_argument&) {} +void +jit_convert::convert_llvm::visit (jit_magic_end& me) +{ + const jit_function& ol = me.overload (); + llvm::Value *ret = ol.call (builder, me.resolve_context ()); + me.stash_llvm (ret); +} + // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : module (0), engine (0) @@ -1823,4 +1844,13 @@ %! endwhile %! assert (i == niter); +%!test +%! niter = 1001; +%! result = 0; +%! m = [5 10]; +%! for i=1:niter +%! result = result + m(end); +%! endfor +%! assert (result == m(end) * niter); + */ diff -r a6d4965ef04b -r bc32288f4a42 src/pt-jit.h --- a/src/pt-jit.h Mon Jul 30 16:23:52 2012 +0100 +++ b/src/pt-jit.h Mon Jul 30 13:05:29 2012 -0500 @@ -244,6 +244,8 @@ std::list all_values; + std::vector end_context; + size_t iterator_count; size_t for_bounds_count; size_t short_count;