changeset 15068:f57d7578c1a6

Support ND matrix indexing with scalar assignment in JIT. * src/jit-typeinfo.cc (make_indices, octave_jit_paren_scalar_subsasgn, jit_typeinfo::gen_subsasgn): New functions. (octave_jit_paren_scalar): Use make_indices. (jit_typeinfo::jit_typeinfo): Call gen_subsasgn. * src/jit-typeinfo.h (jit_typeinfo::gen_subsasgn): New declaration. * src/pt-jit.cc (jit_convert::resolve): Add extra_arg argument. (jit_convert::do_assign): Pass rhs to resolve. * src/pt-jit.h (jit_convert::resolve): Change function signature.
author Max Brister <max@2bass.com>
date Tue, 31 Jul 2012 15:40:52 -0500
parents df4538e3b50b
children 7a3957ca99c3
files src/jit-typeinfo.cc src/jit-typeinfo.h src/pt-jit.cc src/pt-jit.h
diffstat 4 files changed, 106 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/src/jit-typeinfo.cc	Tue Jul 31 11:51:01 2012 -0500
+++ b/src/jit-typeinfo.cc	Tue Jul 31 15:40:52 2012 -0500
@@ -243,6 +243,15 @@
   *ret = *mat;
 }
 
+static void
+make_indices (double *indices, octave_idx_type idx_count,
+              Array<idx_vector>& result)
+{
+  result.resize (dim_vector (1, idx_count));
+  for (octave_idx_type i = 0; i < idx_count; ++i)
+    result(i) = idx_vector (indices[i]);
+}
+
 extern "C" double
 octave_jit_paren_scalar (jit_matrix *mat, double *indicies,
                          octave_idx_type idx_count)
@@ -250,9 +259,8 @@
   // FIXME: Replace this with a more optimal version
   try
     {
-      Array<idx_vector> idx (dim_vector (1, idx_count));
-      for (octave_idx_type i = 0; i < idx_count; ++i)
-        idx(i) = idx_vector (indicies[i]);
+      Array<idx_vector> idx;
+      make_indices (indicies, idx_count, idx);
 
       Array<double> ret = mat->array->index (idx);
       return ret.xelem (0);
@@ -265,6 +273,28 @@
 }
 
 extern "C" void
+octave_jit_paren_scalar_subsasgn (jit_matrix *ret, jit_matrix *mat,
+                                  double *indices, octave_idx_type idx_count,
+                                  double value)
+{
+  // FIXME: Replace this with a more optimal version
+  try
+    {
+      Array<idx_vector> idx;
+      make_indices (indices, idx_count, idx);
+
+      Matrix temp (1, 1);
+      temp.xelem(0) = value;
+      mat->array->assign (idx, temp);
+      ret->update (mat->array);
+    }
+  catch (const octave_execution_exception&)
+    {
+      gripe_library_execution_error ();
+    }
+}
+
+extern "C" void
 octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat,
                                         jit_range *index, double value)
 {
@@ -1342,9 +1372,19 @@
   paren_scalar.add_mapping (engine, &octave_jit_paren_scalar);
   paren_scalar.mark_can_error ();
 
+  jit_function paren_scalar_subsasgn
+    = create_function (jit_convention::external,
+                       "octave_jit_paren_scalar_subsasgn", matrix, matrix,
+                       scalar_ptr, index, scalar);
+  paren_scalar_subsasgn.add_mapping (engine, &octave_jit_paren_scalar_subsasgn);
+  paren_scalar_subsasgn.mark_can_error ();
+
   // FIXME: Generate this on the fly
   for (size_t i = 2; i < 10; ++i)
-    gen_subsref (paren_scalar, i);
+    {
+      gen_subsref (paren_scalar, i);
+      gen_subsasgn (paren_scalar_subsasgn, i);
+    }
 
   // paren subsasgn
   paren_subsasgn_fn.stash_name ("()subsasgn");
@@ -1900,4 +1940,38 @@
   paren_subsref_fn.add_overload (fn);
 }
 
+void
+jit_typeinfo::gen_subsasgn (const jit_function& paren_scalar, size_t n)
+{
+  std::stringstream name;
+  name << "jit_paren_subsasgn_matrix_scalar" << n;
+  std::vector<jit_type *> args (n + 2, scalar);
+  args[0] = matrix;
+  jit_function fn = create_function (jit_convention::internal, name.str (),
+                                     matrix, args);
+  fn.mark_can_error ();
+  llvm::BasicBlock *body = fn.new_block ();
+  builder.SetInsertPoint (body);
+
+  llvm::Type *scalar_t = scalar->to_llvm ();
+  llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n);
+  llvm::Value *array = llvm::UndefValue::get (array_t);
+  for (size_t i = 0; i < n; ++i)
+    {
+      llvm::Value *idx = fn.argument (builder, i + 1);
+      array = builder.CreateInsertValue (array, idx, i);
+    }
+
+  llvm::Value *array_mem = builder.CreateAlloca (array_t);
+  builder.CreateStore (array, array_mem);
+  array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ());
+
+  llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n);
+  llvm::Value *mat = fn.argument (builder, 0);
+  llvm::Value *value = fn.argument (builder, n + 1);
+  llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value);
+  fn.do_return (builder, ret);
+  paren_subsasgn_fn.add_overload (fn);
+}
+
 #endif
--- a/src/jit-typeinfo.h	Tue Jul 31 11:51:01 2012 -0500
+++ b/src/jit-typeinfo.h	Tue Jul 31 15:40:52 2012 -0500
@@ -633,6 +633,8 @@
 
   void gen_subsref (const jit_function& paren_scalar, size_t n);
 
+  void gen_subsasgn (const jit_function& paren_scalar, size_t n);
+
   static jit_typeinfo *instance;
 
   llvm::Module *module;
--- a/src/pt-jit.cc	Tue Jul 31 11:51:01 2012 -0500
+++ b/src/pt-jit.cc	Tue Jul 31 15:40:52 2012 -0500
@@ -810,7 +810,8 @@
 }
 
 jit_instruction *
-jit_convert::resolve (const jit_operation& fres, tree_index_expression& exp)
+jit_convert::resolve (const jit_operation& fres, tree_index_expression& exp,
+                      jit_value *extra_arg)
 {
   std::string type = exp.type_tags ();
   if (! (type.size () == 1 && type[0] == '('))
@@ -832,7 +833,8 @@
 
   size_t narg = arg_list->size ();
   tree_argument_list::iterator iter = arg_list->begin ();
-  std::vector<jit_value *> call_args (narg + 1);
+  bool have_extra = extra_arg;
+  std::vector<jit_value *> call_args (narg + 1 + have_extra);
   call_args[0] = object;
 
   for (size_t idx = 0; iter != arg_list->end (); ++idx, ++iter)
@@ -844,6 +846,9 @@
       call_args[idx + 1] = visit (*iter);
     }
 
+  if (extra_arg)
+    call_args[call_args.size () - 1] = extra_arg;
+
   return create_checked (fres, call_args);
 }
 
@@ -858,7 +863,8 @@
   else if (tree_index_expression *idx
            = dynamic_cast<tree_index_expression *> (exp))
     {
-      jit_value *new_object = resolve (jit_typeinfo::paren_subsasgn (), *idx);
+      jit_value *new_object = resolve (jit_typeinfo::paren_subsasgn (), *idx,
+                                       rhs);
       do_assign (idx->expression (), new_object, true);
 
       // FIXME: Will not work for values that must be release/grabed
@@ -1862,4 +1868,19 @@
 %!   i = i + 1;
 %! endwhile
 %! assert (result == sum (sum (m)));
+
+%!test
+%! ndim = 100;
+%! m = zeros (ndim);
+%! i = 1;
+%! while (i <= ndim)
+%!   for j = 1:ndim
+%!     m(i, j) = (j - 1) * ndim + i;
+%!   endfor
+%!   i = i + 1;
+%! endwhile
+%! m2 = zeros (ndim);
+%! m2(:) = 1:(ndim^2);
+%! assert (all (m == m2));
+
 */
--- a/src/pt-jit.h	Tue Jul 31 11:51:01 2012 -0500
+++ b/src/pt-jit.h	Tue Jul 31 15:40:52 2012 -0500
@@ -297,7 +297,8 @@
   std::string next_name (const char *prefix, size_t& count, bool inc);
 
   jit_instruction *resolve (const jit_operation& fres,
-                            tree_index_expression& exp);
+                            tree_index_expression& exp,
+                            jit_value *extra_arg = 0);
 
   jit_value *do_assign (tree_expression *exp, jit_value *rhs,
                         bool artificial = false);