changeset 15102:d29f2583cf7b

Support end in multi indexing in JIT * src/interp-core/jit-ir.cc (jit_magic_end::context::context): New function. (jit_magic_end::jit_magic_end): Take context vector as argument. (jit_magic_end::resolve_context): Return a context. (jit_magic_end::print): Prettify output. (jit_magic_end::overload): Use context. * src/interp-core/jit-ir.h (jit_magic_end::context::context, jit_magic_end::print): Move implementation to src/jit-ir.cc. (jit_magic_end::short_print): Prettify output. (jit_magic_end::resolve_context): Return a context. * src/interp-core/jit-typeinfo.cc (octave_jit_end_matrix): New function. (jit_typeinfo::jit_typeinfo): Initilaize end_fn and end1_fn. (jit_typeinfo::do_end): New function. (jit_typeinfo::new_type): Moved location in file. * src/interp-core/jit-typeinfo.h (jit_typeinfo::end): Take index and count arguments. (jit_typeinfo::do_end): New declaration. * src/interp-core/pt-jit.cc (jit_convert::resolve): Pass extra argument to context constructor. (jit_convert::convert_llvm::visit): New arguments to jit_magic_end overload.
author Max Brister <max@2bass.com>
date Sat, 04 Aug 2012 00:19:07 -0500
parents 2512448babac
children 03381a36f70d
files src/interp-core/jit-ir.cc src/interp-core/jit-ir.h src/interp-core/jit-typeinfo.cc src/interp-core/jit-typeinfo.h src/interp-core/pt-jit.cc
diffstat 5 files changed, 113 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/src/interp-core/jit-ir.cc	Fri Aug 03 17:38:05 2012 -0700
+++ b/src/interp-core/jit-ir.cc	Sat Aug 04 00:19:07 2012 -0500
@@ -599,38 +599,30 @@
 }
 
 // -------------------- jit_magic_end --------------------
+jit_magic_end::context::context (jit_convert& convert, jit_value *avalue,
+                                 size_t aindex, size_t acount)
+  : value (avalue), index (convert.create<jit_const_index> (aindex)),
+    count (convert.create<jit_const_index> (acount))
+{}
+
 jit_magic_end::jit_magic_end (const std::vector<context>& full_context)
+  : contexts (full_context)
 {
-  // for now we only support end in 1 dimensional indexing
-  resize_arguments (full_context.size ());
+  resize_arguments (contexts.size ());
 
   size_t i;
   std::vector<context>::const_iterator iter;
-  for (iter = full_context.begin (), i = 0; iter != full_context.end (); ++iter,
-         ++i)
-    {
-      if (iter->count != 1)
-        throw jit_fail_exception ("end is only supported in linear contexts");
-      stash_argument (i, iter->value);
-    }
+  for (iter = contexts.begin (), i = 0; iter != contexts.end (); ++iter, ++i)
+    stash_argument (i, iter->value);
 }
 
-const jit_function&
-jit_magic_end::overload () const
-{
-  jit_value *ctx = resolve_context ();
-  if (ctx)
-    return jit_typeinfo::end (ctx->type ());
-
-  static jit_function null_ret;
-  return null_ret;
-}
-
-jit_value *
+jit_magic_end::context
 jit_magic_end::resolve_context (void) const
 {
   // FIXME: We need to have a way of marking functions so we can skip them here
-  return argument_count () ? argument (0) : 0;
+  context ret = contexts[0];
+  ret.value = argument (0);
+  return ret;
 }
 
 bool
@@ -646,4 +638,19 @@
   return false;
 }
 
+std::ostream&
+jit_magic_end::print (std::ostream& os, size_t indent) const
+{
+  context ctx = resolve_context ();
+  short_print (print_indent (os, indent)) << " (" << *ctx.value << ", ";
+  return os << *ctx.index << ", " << *ctx.count << ")";
+}
+
+const jit_function&
+jit_magic_end::overload () const
+{
+  const context& ctx = resolve_context ();
+  return jit_typeinfo::end (ctx.value, ctx.index, ctx.count);
+}
+
 #endif
--- a/src/interp-core/jit-ir.h	Fri Aug 03 17:38:05 2012 -0700
+++ b/src/interp-core/jit-ir.h	Sat Aug 04 00:19:07 2012 -0500
@@ -1162,34 +1162,32 @@
     context (void) : value (0), index (0), count (0)
     {}
 
-    context (jit_value *avalue, size_t aindex, size_t acount)
-      : value (avalue), index (aindex), count (acount)
-    {}
+    context (jit_convert& convert, jit_value *avalue, size_t aindex,
+             size_t acount);
 
     jit_value *value;
-    size_t index;
-    size_t count;
+    jit_const_index *index;
+    jit_const_index *count;
   };
 
   jit_magic_end (const std::vector<context>& full_context);
 
+  virtual bool infer (void);
+
   const jit_function& overload () const;
 
-  jit_value *resolve_context (void) const;
+  virtual std::ostream& print (std::ostream& os, size_t indent = 0) const;
 
-  virtual bool infer (void);
+  context resolve_context (void) const;
 
   virtual std::ostream& short_print (std::ostream& os) const
   {
-    return os << "magic_end";
-  }
-
-  virtual std::ostream& print (std::ostream& os, size_t indent = 0) const
-  {
-    return short_print (print_indent (os, indent));
+    return os << "magic_end" << "#" << id ();
   }
 
   JIT_VALUE_ACCEPT;
+private:
+  std::vector<context> contexts;
 };
 
 class
--- a/src/interp-core/jit-typeinfo.cc	Fri Aug 03 17:38:05 2012 -0700
+++ b/src/interp-core/jit-typeinfo.cc	Sat Aug 04 00:19:07 2012 -0500
@@ -343,6 +343,29 @@
   result->update (array);
 }
 
+extern "C" double
+octave_jit_end_matrix (jit_matrix *mat, octave_idx_type idx,
+                       octave_idx_type count)
+{
+  octave_idx_type ndim = mat->dimensions[-1];
+  if (ndim == count)
+    return mat->dimensions[idx];
+  else if (ndim > count)
+    {
+      if (idx == count - 1)
+        {
+          double ret = mat->dimensions[idx];
+          for (octave_idx_type i = idx + 1; i < ndim; ++i)
+            ret *= mat->dimensions[idx];
+          return ret;
+        }
+
+      return mat->dimensions[idx];
+    }
+  else // ndim < count
+    return idx < ndim ? mat->dimensions[idx] : 1;
+}
+
 extern "C" Complex
 octave_jit_complex_div (Complex lhs, Complex rhs)
 {
@@ -1626,9 +1649,9 @@
   fn.mark_can_error ();
   paren_subsasgn_fn.add_overload (fn);
 
-  end_fn.stash_name ("end");
-  fn = create_function (jit_convention::internal, "octave_jit_end_matrix",
-                        scalar, matrix);
+  end1_fn.stash_name ("end1");
+  fn = create_function (jit_convention::internal, "octave_jit_end1_matrix",
+                        scalar, matrix, index, index);
   body = fn.new_block ();
   builder.SetInsertPoint (body);
   {
@@ -1636,6 +1659,11 @@
     llvm::Value *ret = builder.CreateExtractValue (mat, 2);
     fn.do_return (builder, builder.CreateSIToFP (ret, scalar_t));
   }
+  end1_fn.add_overload (fn);
+
+  end_fn.stash_name ("end");
+  fn = create_function (jit_convention::external, "octave_jit_end_matrix",
+                        scalar, matrix, index, index);
   end_fn.add_overload (fn);
 
   casts[any->type_id ()].stash_name ("(any)");
@@ -1760,6 +1788,25 @@
     }
 }
 
+const jit_function&
+jit_typeinfo::do_end (jit_value *value, jit_value *idx, jit_value *count)
+{
+  jit_const_index *ccount = dynamic_cast<jit_const_index *> (count);
+  if (ccount && ccount->value () == 1)
+    return end1_fn.overload (value->type (), idx->type (), count->type ());
+
+  return end_fn.overload (value->type (), idx->type (), count->type ());
+}
+
+jit_type*
+jit_typeinfo::new_type (const std::string& name, jit_type *parent,
+                        llvm::Type *llvm_type)
+{
+  jit_type *ret = new jit_type (name, parent, llvm_type, next_id++);
+  id_to_type.push_back (ret);
+  return ret;
+}
+
 void
 jit_typeinfo::add_print (jit_type *ty, void *fptr)
 {
@@ -2059,13 +2106,4 @@
   return get_any ();
 }
 
-jit_type*
-jit_typeinfo::new_type (const std::string& name, jit_type *parent,
-                        llvm::Type *llvm_type)
-{
-  jit_type *ret = new jit_type (name, parent, llvm_type, next_id++);
-  id_to_type.push_back (ret);
-  return ret;
-}
-
 #endif
--- a/src/interp-core/jit-typeinfo.h	Fri Aug 03 17:38:05 2012 -0700
+++ b/src/interp-core/jit-typeinfo.h	Sat Aug 04 00:19:07 2012 -0500
@@ -267,6 +267,7 @@
 
   JIT_CALL (1);
   JIT_CALL (2);
+  JIT_CALL (3);
 
 #undef JIT_CALL
 #undef JIT_PARAMS
@@ -549,9 +550,10 @@
     return instance->end_fn;
   }
 
-  static const jit_function& end (jit_type *ty)
+  static const jit_function& end (jit_value *value, jit_value *index,
+                                  jit_value *count)
   {
-    return instance->end_fn.overload (ty);
+    return instance->do_end (value, index, count);
   }
 private:
   jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e);
@@ -619,6 +621,9 @@
     return do_cast (to).overload (from);
   }
 
+  const jit_function& do_end (jit_value *value, jit_value *index,
+                              jit_value *count);
+
   jit_type *new_type (const std::string& name, jit_type *parent,
                       llvm::Type *llvm_type);
 
@@ -738,6 +743,7 @@
   jit_operation make_range_fn;
   jit_paren_subsref paren_subsref_fn;
   jit_paren_subsasgn paren_subsasgn_fn;
+  jit_operation end1_fn;
   jit_operation end_fn;
 
   // type id -> cast function TO that type
--- a/src/interp-core/pt-jit.cc	Fri Aug 03 17:38:05 2012 -0700
+++ b/src/interp-core/pt-jit.cc	Sat Aug 04 00:19:07 2012 -0500
@@ -842,7 +842,7 @@
       unwind_protect prot;
       prot.add_method (&end_context,
                        &std::vector<jit_magic_end::context>::pop_back);
-      end_context.push_back (jit_magic_end::context (object, idx, narg));
+      end_context.push_back (jit_magic_end::context (*this, object, idx, narg));
       call_args[idx + 1] = visit (*iter);
     }
 
@@ -1498,7 +1498,9 @@
 jit_convert::convert_llvm::visit (jit_magic_end& me)
 {
   const jit_function& ol = me.overload ();
-  llvm::Value *ret = ol.call (builder, me.resolve_context ());
+
+  jit_magic_end::context ctx = me.resolve_context ();
+  llvm::Value *ret = ol.call (builder, ctx.value, ctx.index, ctx.count);
   me.stash_llvm (ret);
 }
 
@@ -1927,4 +1929,15 @@
 %! endwhile
 %! assert (result, 0);
 
+%!test
+%! m = zeros (2, 1001);
+%! for i=1:1001
+%!   m(end, i) = i;
+%!   m(end - 1, end - i + 1) = i;
+%! endfor
+%! m2 = zeros (2, 1001);
+%! m2(1, :) = fliplr (1:1001);
+%! m2(2, :) = 1:1001;
+%! assert (m, m2);
+
 */