changeset 15056:bc32288f4a42

Support the end keyword for one dimentional indexing in JIT. * src/jit-ir.cc (jit_magic_end): New class. * src/jit-ir.h (jit_magic_end): New class. (jit_instruction::jit_instruction): New overload. * src/jit-typeinfo.cc (jit_function::call): Throw jit_fail_exception if invalid. (jit_typeinfo::jit_typeinfo): Initialize end_fn. * src/jit-typeinfo.h (jit_typeinfo::end): New function. * src/pt-jit.cc (jit_convert::visit_identifier): Handle magic_end. (jit_convert::resolve): Keep track of end context. (jit_convert::convert_llvm::visit): New overload. * src/pt-jit.h (jit_convert): Add end_context.
author Max Brister <max@2bass.com>
date Mon, 30 Jul 2012 13:05:29 -0500
parents a6d4965ef04b
children 46b19589b593 6130d87495b8
files src/jit-ir.cc src/jit-ir.h src/jit-typeinfo.cc src/jit-typeinfo.h src/pt-jit.cc src/pt-jit.h
diffstat 6 files changed, 133 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/src/jit-ir.cc	Mon Jul 30 16:23:52 2012 +0100
+++ b/src/jit-ir.cc	Mon Jul 30 13:05:29 2012 -0500
@@ -598,4 +598,36 @@
   return false;
 }
 
+// -------------------- jit_magic_end --------------------
+const jit_function&
+jit_magic_end::overload () const
+{
+  jit_value *ctx = resolve_context ();
+  if (ctx)
+    return jit_typeinfo::end (ctx->type ());
+
+  static jit_function null_ret;
+  return null_ret;
+}
+
+jit_value *
+jit_magic_end::resolve_context (void) const
+{
+  // FIXME: We need to have a way of marking functions so we can skip them here
+  return argument_count () ? argument (0) : 0;
+}
+
+bool
+jit_magic_end::infer (void)
+{
+  jit_type *new_type = overload ().result ();
+  if (new_type != type ())
+    {
+      stash_type (new_type);
+      return true;
+    }
+
+  return false;
+}
+
 #endif
--- a/src/jit-ir.h	Mon Jul 30 16:23:52 2012 +0100
+++ b/src/jit-ir.h	Mon Jul 30 13:05:29 2012 -0500
@@ -46,7 +46,8 @@
   JIT_METH(variable);                           \
   JIT_METH(error_check);                        \
   JIT_METH(assign)                              \
-  JIT_METH(argument)
+  JIT_METH(argument)                            \
+  JIT_METH(magic_end)
 
 #define JIT_VISIT_IR_CONST                      \
   JIT_METH(const_bool);                         \
@@ -256,6 +257,14 @@
 #undef STASH_ARG
 #undef JIT_INSTRUCTION_CTOR
 
+  jit_instruction (const std::vector<jit_value *>& aarguments)
+  : already_infered (aarguments.size ()), marguments (aarguments.size ()),
+    mid (next_id ()), mparent (0)
+  {
+    for (size_t i = 0; i < aarguments.size (); ++i)
+      stash_argument (i, aarguments[i]);
+  }
+
   static void reset_ids (void)
   {
     next_id (true);
@@ -1137,6 +1146,34 @@
   }
 };
 
+// for now only handles the 1D case
+class
+jit_magic_end : public jit_instruction
+{
+public:
+  jit_magic_end (const std::vector<jit_value *>& context)
+    : jit_instruction (context)
+  {}
+
+  const jit_function& overload () const;
+
+  jit_value *resolve_context (void) const;
+
+  virtual bool infer (void);
+
+  virtual std::ostream& short_print (std::ostream& os) const
+  {
+    return os << "magic_end";
+  }
+
+  virtual std::ostream& print (std::ostream& os, size_t indent = 0) const
+  {
+    return short_print (print_indent (os, indent));
+  }
+
+  JIT_VALUE_ACCEPT;
+};
+
 class
 jit_extract_argument : public jit_assign_base
 {
--- a/src/jit-typeinfo.cc	Mon Jul 30 16:23:52 2012 +0100
+++ b/src/jit-typeinfo.cc	Mon Jul 30 13:05:29 2012 -0500
@@ -522,8 +522,10 @@
 jit_function::call (llvm::IRBuilderD& builder,
                     const std::vector<jit_value *>& in_args) const
 {
+  if (! valid ())
+    throw jit_fail_exception ("Call not implemented");
+
   assert (in_args.size () == args.size ());
-
   std::vector<llvm::Value *> llvm_args (args.size ());
   for (size_t i = 0; i < in_args.size (); ++i)
     llvm_args[i] = in_args[i]->to_llvm ();
@@ -535,7 +537,9 @@
 jit_function::call (llvm::IRBuilderD& builder,
                     const std::vector<llvm::Value *>& in_args) const
 {
-  assert (valid ());
+  if (! valid ())
+    throw jit_fail_exception ("Call not implemented");
+
   assert (in_args.size () == args.size ());
   llvm::Function *stacksave
     = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave);
@@ -1342,8 +1346,7 @@
     builder.CreateBr (done);
 
     builder.SetInsertPoint (normal);
-    llvm::Value *len = builder.CreateExtractValue (mat,
-                                                   llvm::ArrayRef<unsigned> (2));
+    llvm::Value *len = builder.CreateExtractValue (mat, 2);
     cond0 = builder.CreateICmpSGT (int_idx, len);
 
     llvm::Value *rcount = builder.CreateExtractValue (mat, 0);
@@ -1386,6 +1389,18 @@
   fn.mark_can_error ();
   paren_subsasgn_fn.add_overload (fn);
 
+  end_fn.stash_name ("end");
+  fn = create_function (jit_convention::internal, "octave_jit_end_matrix",
+                        scalar, matrix);
+  body = fn.new_block ();
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *mat = fn.argument (builder, 0);
+    llvm::Value *ret = builder.CreateExtractValue (mat, 2);
+    fn.do_return (builder, builder.CreateSIToFP (ret, scalar_t));
+  }
+  end_fn.add_overload (fn);
+
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
   casts[complex->type_id ()].stash_name ("(complex)");
--- a/src/jit-typeinfo.h	Mon Jul 30 16:23:52 2012 +0100
+++ b/src/jit-typeinfo.h	Mon Jul 30 13:05:29 2012 -0500
@@ -471,6 +471,16 @@
   {
     return instance->do_insert_error_check (bld);
   }
+
+  static const jit_operation& end (void)
+  {
+    return instance->end_fn;
+  }
+
+  static const jit_function& end (jit_type *ty)
+  {
+    return instance->end_fn.overload (ty);
+  }
 private:
   jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e);
 
@@ -655,6 +665,7 @@
   jit_operation make_range_fn;
   jit_operation paren_subsref_fn;
   jit_operation paren_subsasgn_fn;
+  jit_operation end_fn;
 
   // type id -> cast function TO that type
   std::vector<jit_operation> casts;
--- a/src/pt-jit.cc	Mon Jul 30 16:23:52 2012 +0100
+++ b/src/pt-jit.cc	Mon Jul 30 13:05:29 2012 -0500
@@ -412,7 +412,14 @@
 void
 jit_convert::visit_identifier (tree_identifier& ti)
 {
-  result = get_variable (ti.name ());
+  if (ti.has_magic_end ())
+    {
+      if (!end_context.size ())
+        throw jit_fail_exception ("Illegal end");
+      result = block->append (create<jit_magic_end> (end_context));
+    }
+  else
+    result = get_variable (ti.name ());
 }
 
 void
@@ -826,6 +833,12 @@
 
   tree_expression *tree_object = exp.expression ();
   jit_value *object = visit (tree_object);
+
+  end_context.push_back (object);
+
+  unwind_protect prot;
+  prot.add_method (&end_context, &std::vector<jit_value *>::pop_back);
+
   tree_expression *arg0 = arg_list->front ();
   jit_value *index = visit (arg0);
 
@@ -1479,6 +1492,14 @@
 jit_convert::convert_llvm::visit (jit_argument&)
 {}
 
+void
+jit_convert::convert_llvm::visit (jit_magic_end& me)
+{
+  const jit_function& ol = me.overload ();
+  llvm::Value *ret = ol.call (builder, me.resolve_context ());
+  me.stash_llvm (ret);
+}
+
 // -------------------- tree_jit --------------------
 
 tree_jit::tree_jit (void) : module (0), engine (0)
@@ -1823,4 +1844,13 @@
 %! endwhile
 %! assert (i == niter);
 
+%!test
+%! niter = 1001;
+%! result = 0;
+%! m = [5 10];
+%! for i=1:niter
+%!   result = result + m(end);
+%! endfor
+%! assert (result == m(end) * niter);
+
 */
--- a/src/pt-jit.h	Mon Jul 30 16:23:52 2012 +0100
+++ b/src/pt-jit.h	Mon Jul 30 13:05:29 2012 -0500
@@ -244,6 +244,8 @@
 
   std::list<jit_value *> all_values;
 
+  std::vector<jit_value *> end_context;
+
   size_t iterator_count;
   size_t for_bounds_count;
   size_t short_count;