changeset 15078:fe4752f772e2

Generate ND indexing functions on demand in JIT. * src/jit-typeinfo.cc (jit_operation::~jit_operation, jit_operation::do_generate, jit_operation::generate, jit_operation::signature_cmp::operator()): New function. (jit_operation::overload): Call do_generate when lookup fails. (jit_index_operation, jit_paren_subsref, jit_paren_subsasgn): New class. (jit_typeinfo::jit_typeinfo): Update to use jit_paren_subsref and jit_paren_subsasgn. (jit_typeinfo::gen_subsref, jit_typeinfo::gen_subsasgn): Removed functions. * src/jit-typeinfo.h (jit_operation::~jit_operation, jit_operation::generate, jit_operation::do_generate): New declaration. (jit_operation::add_overload, jit_operation::overload, jit_operation::result, jit_operation::to_idx): Use signature_vec typedef. (jit_operation::singature_cmp): New class. (jit_index_operation, jit_paren_subsref, jit_paren_subsasgn): New class. (jit_typeinfo::get_scalar_ptr): Nwe function. (jit_typeinfo::gen_subsref, jit_typeinfo::gen_subsasgn): Removed declaration. * src/pt-jit.cc: New test.
author Max Brister <max@2bass.com>
date Wed, 01 Aug 2012 17:00:12 -0500
parents f0b04a20d7cf
children dda73cb60ac5
files src/jit-typeinfo.cc src/jit-typeinfo.h src/pt-jit.cc
diffstat 3 files changed, 292 insertions(+), 108 deletions(-) [+]
line wrap: on
line diff
--- a/src/jit-typeinfo.cc	Wed Aug 01 12:10:26 2012 -0400
+++ b/src/jit-typeinfo.cc	Wed Aug 01 17:00:12 2012 -0500
@@ -708,6 +708,16 @@
 }
 
 // -------------------- jit_operation --------------------
+jit_operation::~jit_operation (void)
+{
+  for (generated_map::iterator iter = generated.begin ();
+       iter != generated.end (); ++iter)
+    {
+      delete iter->first;
+      delete iter->second;
+    }
+}
+
 void
 jit_operation::add_overload (const jit_function& func,
                             const std::vector<jit_type*>& args)
@@ -742,23 +752,26 @@
 const jit_function&
 jit_operation::overload (const std::vector<jit_type*>& types) const
 {
-  // FIXME: We should search for the next best overload on failure
   static jit_function null_overload;
-  if (types.size () >= overloads.size ())
-    return null_overload;
-
   for (size_t i  =0; i < types.size (); ++i)
     if (! types[i])
       return null_overload;
 
+  if (types.size () >= overloads.size ())
+    return do_generate (types);
+
   const Array<jit_function>& over = overloads[types.size ()];
   dim_vector dv (over.dims ());
   Array<octave_idx_type> idx = to_idx (types);
   for (octave_idx_type i = 0; i < dv.length (); ++i)
     if (idx(i) >= dv(i))
-      return null_overload;
+      return do_generate (types);
 
-  return over(idx);
+  const jit_function& ret = over(idx);
+  if (! ret.valid ())
+    return do_generate (types);
+
+  return ret;
 }
 
 Array<octave_idx_type>
@@ -782,6 +795,175 @@
   return idx;
 }
 
+const jit_function&
+jit_operation::do_generate (const signature_vec& types) const
+{
+  static jit_function null_overload;
+  generated_map::const_iterator find = generated.find (&types);
+  if (find != generated.end ())
+    {
+      if (find->second)
+        return *find->second;
+      else
+        return null_overload;
+    }
+
+  jit_function *ret = generate (types);
+  generated[new signature_vec (types)] = ret;
+  return ret ? *ret : null_overload;
+}
+
+jit_function *
+jit_operation::generate (const signature_vec& types) const
+{
+  return 0;
+}
+
+bool
+jit_operation::signature_cmp
+::operator() (const signature_vec *lhs, const signature_vec *rhs)
+{
+  const signature_vec& l = *lhs;
+  const signature_vec& r = *rhs;
+
+  if (l.size () < r.size ())
+    return true;
+  else if (l.size () > r.size ())
+    return false;
+
+  for (size_t i = 0; i < l.size (); ++i)
+    {
+      if (l[i]->type_id () < r[i]->type_id ())
+        return true;
+      else if (l[i]->type_id () > r[i]->type_id ())
+        return false;
+    }
+
+  return false;
+}
+
+// -------------------- jit_index_operation --------------------
+jit_function *
+jit_index_operation::generate (const signature_vec& types) const
+{
+  if (types.size () > 2 && types[0] == jit_typeinfo::get_matrix ())
+    {
+      // indexing a matrix with scalars
+      jit_type *scalar = jit_typeinfo::get_scalar ();
+      for (size_t i = 1; i < types.size (); ++i)
+        if (types[i] != scalar)
+          return 0;
+
+      return generate_matrix (types);
+    }
+
+  return 0;
+}
+
+llvm::Value *
+jit_index_operation::create_arg_array (llvm::IRBuilderD& builder,
+                                       const jit_function &fn, size_t start_idx,
+                                       size_t end_idx) const
+{
+  size_t n = end_idx - start_idx;
+  llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm ();
+  llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n);
+  llvm::Value *array = llvm::UndefValue::get (array_t);
+  for (size_t i = start_idx; i < end_idx; ++i)
+    {
+      llvm::Value *idx = fn.argument (builder, i);
+      array = builder.CreateInsertValue (array, idx, i - start_idx);
+    }
+
+  llvm::Value *array_mem = builder.CreateAlloca (array_t);
+  builder.CreateStore (array, array_mem);
+  return builder.CreateBitCast (array_mem, scalar_t->getPointerTo ());
+}
+
+// -------------------- jit_paren_subsref --------------------
+jit_function *
+jit_paren_subsref::generate_matrix (const signature_vec& types) const
+{
+  std::stringstream ss;
+  ss << "jit_paren_subsref_matrix_scalar" << (types.size () - 1);
+
+  jit_type *scalar = jit_typeinfo::get_scalar ();
+  jit_function *fn = new jit_function (module, jit_convention::internal,
+                                       ss.str (), scalar, types);
+  fn->mark_can_error ();
+  llvm::BasicBlock *body = fn->new_block ();
+  llvm::IRBuilder<> builder (body);
+
+  llvm::Value *array = create_arg_array (builder, *fn, 1, types.size ());
+  jit_type *index = jit_typeinfo::get_index ();
+  llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (),
+                                               types.size () - 1);
+  llvm::Value *mat = fn->argument (builder, 0);
+  llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem);
+  fn->do_return (builder, ret);
+  return fn;
+}
+
+void
+jit_paren_subsref::do_initialize (void)
+{
+  std::vector<jit_type *> types (3);
+  types[0] = jit_typeinfo::get_matrix ();
+  types[1] = jit_typeinfo::get_scalar_ptr ();
+  types[2] = jit_typeinfo::get_index ();
+
+  jit_type *scalar = jit_typeinfo::get_scalar ();
+  paren_scalar = jit_function (module, jit_convention::external,
+                               "octave_jit_paren_scalar", scalar, types);
+  paren_scalar.add_mapping (engine, &octave_jit_paren_scalar);
+  paren_scalar.mark_can_error ();
+}
+
+// -------------------- jit_paren_subsasgn --------------------
+jit_function *
+jit_paren_subsasgn::generate_matrix (const signature_vec& types) const
+{
+  std::stringstream ss;
+  ss << "jit_paren_subsasgn_matrix_scalar" << (types.size () - 2);
+
+  jit_type *matrix = jit_typeinfo::get_matrix ();
+  jit_function *fn = new jit_function (module, jit_convention::internal,
+                                       ss.str (), matrix, types);
+  fn->mark_can_error ();
+  llvm::BasicBlock *body = fn->new_block ();
+  llvm::IRBuilder<> builder (body);
+
+  llvm::Value *array = create_arg_array (builder, *fn, 1, types.size () - 1);
+  jit_type *index = jit_typeinfo::get_index ();
+  llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (),
+                                               types.size () - 2);
+
+  llvm::Value *mat = fn->argument (builder, 0);
+  llvm::Value *value = fn->argument (builder, types.size () - 1);
+  llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value);
+  fn->do_return (builder, ret);
+  return fn;
+}
+
+void
+jit_paren_subsasgn::do_initialize (void)
+{
+  if (paren_scalar.valid ())
+    return;
+
+  jit_type *matrix = jit_typeinfo::get_matrix ();
+  std::vector<jit_type *> types (4);
+  types[0] = matrix;
+  types[1] = jit_typeinfo::get_scalar_ptr ();
+  types[2] = jit_typeinfo::get_index ();
+  types[3] = jit_typeinfo::get_scalar ();
+
+  paren_scalar = jit_function (module, jit_convention::external,
+                               "octave_jit_paren_scalar", matrix, types);
+  paren_scalar.add_mapping (engine, &octave_jit_paren_scalar_subsasgn);
+  paren_scalar.mark_can_error ();
+}
+
 // -------------------- jit_typeinfo --------------------
 void
 jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e)
@@ -835,14 +1017,12 @@
   matrix = new_type ("matrix", any, matrix_t);
   complex = new_type ("complex", any, complex_t);
   scalar = new_type ("scalar", complex, scalar_t);
+  scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ());
   range = new_type ("range", any, range_t);
   string = new_type ("string", any, string_t);
   boolean = new_type ("bool", any, bool_t);
   index = new_type ("index", any, index_t);
 
-  // a fake type for interfacing with C++
-  jit_type *scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ());
-
   create_int (8);
   create_int (16);
   create_int (32);
@@ -867,6 +1047,9 @@
   if (sizeof (void *) == 4)
     complex->mark_sret ();
 
+  paren_subsref_fn.initialize (module, engine);
+  paren_subsasgn_fn.initialize (module, engine);
+
   // bind global variables
   lerror_state = new llvm::GlobalVariable (*module, bool_t, false,
                                            llvm::GlobalValue::ExternalLinkage,
@@ -1364,28 +1547,6 @@
   }
   paren_subsref_fn.add_overload (fn);
 
-  // generate () subsref for ND indexing of matricies with scalars
-  jit_function paren_scalar = create_function (jit_convention::external,
-                                               "octave_jit_paren_scalar",
-                                               scalar, matrix, scalar_ptr,
-                                               index);
-  paren_scalar.add_mapping (engine, &octave_jit_paren_scalar);
-  paren_scalar.mark_can_error ();
-
-  jit_function paren_scalar_subsasgn
-    = create_function (jit_convention::external,
-                       "octave_jit_paren_scalar_subsasgn", matrix, matrix,
-                       scalar_ptr, index, scalar);
-  paren_scalar_subsasgn.add_mapping (engine, &octave_jit_paren_scalar_subsasgn);
-  paren_scalar_subsasgn.mark_can_error ();
-
-  // FIXME: Generate this on the fly
-  for (size_t i = 2; i < 10; ++i)
-    {
-      gen_subsref (paren_scalar, i);
-      gen_subsasgn (paren_scalar_subsasgn, i);
-    }
-
   // paren subsasgn
   paren_subsasgn_fn.stash_name ("()subsasgn");
 
@@ -1907,71 +2068,4 @@
   return ret;
 }
 
-void
-jit_typeinfo::gen_subsref (const jit_function& paren_scalar, size_t n)
-{
-  std::stringstream name;
-  name << "jit_paren_subsref_matrix_scalar" << n;
-  std::vector<jit_type *> args (n + 1, scalar);
-  args[0] = matrix;
-  jit_function fn = create_function (jit_convention::internal, name.str (),
-                                     scalar, args);
-  fn.mark_can_error ();
-  llvm::BasicBlock *body = fn.new_block ();
-  builder.SetInsertPoint (body);
-
-  llvm::Type *scalar_t = scalar->to_llvm ();
-  llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n);
-  llvm::Value *array = llvm::UndefValue::get (array_t);
-  for (size_t i = 0; i < n; ++i)
-    {
-      llvm::Value *idx = fn.argument (builder, i + 1);
-      array = builder.CreateInsertValue (array, idx, i);
-    }
-
-  llvm::Value *array_mem = builder.CreateAlloca (array_t);
-  builder.CreateStore (array, array_mem);
-  array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ());
-
-  llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n);
-  llvm::Value *mat = fn.argument (builder, 0);
-  llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem);
-  fn.do_return (builder, ret);
-  paren_subsref_fn.add_overload (fn);
-}
-
-void
-jit_typeinfo::gen_subsasgn (const jit_function& paren_scalar, size_t n)
-{
-  std::stringstream name;
-  name << "jit_paren_subsasgn_matrix_scalar" << n;
-  std::vector<jit_type *> args (n + 2, scalar);
-  args[0] = matrix;
-  jit_function fn = create_function (jit_convention::internal, name.str (),
-                                     matrix, args);
-  fn.mark_can_error ();
-  llvm::BasicBlock *body = fn.new_block ();
-  builder.SetInsertPoint (body);
-
-  llvm::Type *scalar_t = scalar->to_llvm ();
-  llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n);
-  llvm::Value *array = llvm::UndefValue::get (array_t);
-  for (size_t i = 0; i < n; ++i)
-    {
-      llvm::Value *idx = fn.argument (builder, i + 1);
-      array = builder.CreateInsertValue (array, idx, i);
-    }
-
-  llvm::Value *array_mem = builder.CreateAlloca (array_t);
-  builder.CreateStore (array, array_mem);
-  array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ());
-
-  llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n);
-  llvm::Value *mat = fn.argument (builder, 0);
-  llvm::Value *value = fn.argument (builder, n + 1);
-  llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value);
-  fn.do_return (builder, ret);
-  paren_subsasgn_fn.add_overload (fn);
-}
-
 #endif
--- a/src/jit-typeinfo.h	Wed Aug 01 12:10:26 2012 -0400
+++ b/src/jit-typeinfo.h	Wed Aug 01 17:00:12 2012 -0500
@@ -314,17 +314,22 @@
 jit_operation
 {
 public:
+  // type signature vector
+  typedef std::vector<jit_type *> signature_vec;
+
+  virtual ~jit_operation (void);
+
   void add_overload (const jit_function& func)
   {
     add_overload (func, func.arguments ());
   }
 
   void add_overload (const jit_function& func,
-                     const std::vector<jit_type*>& args);
+                     const signature_vec& args);
 
-  const jit_function& overload (const std::vector<jit_type *>& types) const;
+  const jit_function& overload (const signature_vec& types) const;
 
-  jit_type *result (const std::vector<jit_type *>& types) const
+  jit_type *result (const signature_vec& types) const
   {
     const jit_function& temp = overload (types);
     return temp.result ();
@@ -346,14 +351,79 @@
   const std::string& name (void) const { return mname; }
 
   void stash_name (const std::string& aname) { mname = aname; }
+protected:
+  virtual jit_function *generate (const signature_vec& types) const;
 private:
-  Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const;
+  Array<octave_idx_type> to_idx (const signature_vec& types) const;
+
+  const jit_function& do_generate (const signature_vec& types) const;
+
+  struct signature_cmp
+  {
+    bool operator() (const signature_vec *lhs, const signature_vec *rhs);
+  };
+
+  typedef std::map<const signature_vec *, jit_function *, signature_cmp>
+  generated_map;
+
+  mutable generated_map generated;
 
   std::vector<Array<jit_function> > overloads;
 
   std::string mname;
 };
 
+class
+jit_index_operation : public jit_operation
+{
+public:
+  jit_index_operation (void) : module (0), engine (0) {}
+
+  void initialize (llvm::Module *amodule, llvm::ExecutionEngine *aengine)
+  {
+    module = amodule;
+    engine = aengine;
+    do_initialize ();
+  }
+protected:
+  virtual jit_function *generate (const signature_vec& types) const;
+
+  virtual jit_function *generate_matrix (const signature_vec& types) const = 0;
+
+  virtual void do_initialize (void) = 0;
+
+  // helper functions
+  // [start_idx, end_idx).
+  llvm::Value *create_arg_array (llvm::IRBuilderD& builder,
+                                 const jit_function &fn, size_t start_idx,
+                                 size_t end_idx) const;
+
+  llvm::Module *module;
+  llvm::ExecutionEngine *engine;
+};
+
+class
+jit_paren_subsref : public jit_index_operation
+{
+protected:
+  virtual jit_function *generate_matrix (const signature_vec& types) const;
+
+  virtual void do_initialize (void);
+private:
+  jit_function paren_scalar;
+};
+
+class
+jit_paren_subsasgn : public jit_index_operation
+{
+protected:
+  jit_function *generate_matrix (const signature_vec& types) const;
+
+  virtual void do_initialize (void);
+private:
+  jit_function paren_scalar;
+};
+
 // A singleton class which handles the construction of jit_types and
 // jit_operations.
 class
@@ -376,6 +446,8 @@
   static llvm::Type *get_scalar_llvm (void)
   { return instance->scalar->to_llvm (); }
 
+  static jit_type *get_scalar_ptr (void) { return instance->scalar_ptr; }
+
   static jit_type *get_range (void) { return instance->range; }
 
   static jit_type *get_string (void) { return instance->string; }
@@ -631,10 +703,6 @@
 
   jit_type *intN (size_t nbits) const;
 
-  void gen_subsref (const jit_function& paren_scalar, size_t n);
-
-  void gen_subsasgn (const jit_function& paren_scalar, size_t n);
-
   static jit_typeinfo *instance;
 
   llvm::Module *module;
@@ -647,6 +715,7 @@
   jit_type *any;
   jit_type *matrix;
   jit_type *scalar;
+  jit_type *scalar_ptr; // a fake type for interfacing with C++
   jit_type *range;
   jit_type *string;
   jit_type *boolean;
@@ -667,8 +736,8 @@
   jit_operation for_index_fn;
   jit_operation logically_true_fn;
   jit_operation make_range_fn;
-  jit_operation paren_subsref_fn;
-  jit_operation paren_subsasgn_fn;
+  jit_paren_subsref paren_subsref_fn;
+  jit_paren_subsasgn paren_subsasgn_fn;
   jit_operation end_fn;
 
   // type id -> cast function TO that type
--- a/src/pt-jit.cc	Wed Aug 01 12:10:26 2012 -0400
+++ b/src/pt-jit.cc	Wed Aug 01 17:00:12 2012 -0500
@@ -1883,4 +1883,25 @@
 %! m2(:) = 1:(ndim^2);
 %! assert (all (m == m2));
 
+%!test
+%! ndim = 2;
+%! m = zeros (ndim, ndim, ndim, ndim);
+%! result = 0;
+%! i0 = 1;
+%! while (i0 <= ndim)
+%!   for i1 = 1:ndim
+%!     for i2 = 1:ndim
+%!       for i3 = 1:ndim
+%!         m(i0, i1, i2, i3) = 1;
+%!         m(i0, i1, i2, i3, 1, 1, 1, 1, 1, 1) = 1;
+%!         result = result + m(i0, i1, i2, i3);
+%!       endfor
+%!     endfor
+%!   endfor
+%!   i0 = i0 + 1;
+%! endwhile
+%! expected = ones (ndim, ndim, ndim, ndim);
+%! assert (all (m == expected));
+%! assert (result == sum (expected (:)));
+
 */