changeset 14936:32deb562ae77

Allow for construction of ranges during jit
author Max Brister <max@2bass.com>
date Mon, 04 Jun 2012 17:18:47 -0500
parents 5801e031a3b5
children 78e1457c5bf5
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 111 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc	Mon Jun 04 13:10:44 2012 -0500
+++ b/src/pt-jit.cc	Mon Jun 04 17:18:47 2012 -0500
@@ -130,6 +130,13 @@
   return rep;
 }
 
+extern "C" octave_idx_type
+octave_jit_compute_nelem (double base, double limit, double inc)
+{
+  Range rng = Range (base, limit, inc);
+  return rng.nelem (); 
+}
+
 extern "C" void
 octave_jit_release_any (octave_base_value *obv)
 {
@@ -460,6 +467,37 @@
   logically_true.add_overload (fn, false, false, boolean, boolean);
   logically_true.stash_name ("logically_true");
 
+  // make_range
+  // FIXME: May be benificial to implement all in LLVM
+  make_range_fn.stash_name ("make_range");
+  llvm::Function *compute_nelem
+    = create_function ("octave_jit_compute_nelem", index, scalar, scalar, scalar);
+  engine->addGlobalMapping (compute_nelem,
+                            reinterpret_cast<void*> (&octave_jit_compute_nelem));
+
+  fn = create_function ("octave_jit_make_range", range, scalar, scalar, scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Function::arg_iterator args = fn->arg_begin ();
+    llvm::Value *base = args;
+    llvm::Value *limit = ++args;
+    llvm::Value *inc = ++args;
+    llvm::Value *nelem = builder.CreateCall3 (compute_nelem, base, limit, inc);
+
+    llvm::Value *dzero = llvm::ConstantFP::get (dbl, 0);
+    llvm::Value *izero = llvm::ConstantInt::get (index_t, 0);
+    llvm::Value *rng = llvm::ConstantStruct::get (range_t, dzero, dzero, dzero,
+                                                  izero, NULL);
+    rng = builder.CreateInsertValue (rng, base, 0);
+    rng = builder.CreateInsertValue (rng, limit, 1);
+    rng = builder.CreateInsertValue (rng, inc, 2);
+    rng = builder.CreateInsertValue (rng, nelem, 3);
+    builder.CreateRet (rng);
+  }
+  llvm::verifyFunction (*fn);
+  make_range_fn.add_overload (fn, false, false, range, scalar, scalar, scalar);
+
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
 
@@ -1058,6 +1096,8 @@
         final_block->append (create<jit_store_argument> (var));
     }
 
+  print_blocks ("octave jit ir");
+
   construct_ssa (final_block);
 
   // initialize the worklist to instructions derived from constants
@@ -1148,9 +1188,25 @@
 }
 
 void
-jit_convert::visit_colon_expression (tree_colon_expression&)
+jit_convert::visit_colon_expression (tree_colon_expression& expr)
 {
-  fail ();
+  // in the futher we need to add support for classes and deal with rvalues
+  jit_instruction *base = visit (expr.base ());
+  jit_instruction *limit = visit (expr.limit ());
+  jit_instruction *increment;
+  tree_expression *tinc = expr.increment ();
+
+  if (tinc)
+    increment = visit (tinc);
+  else
+    {
+      increment = create<jit_const_scalar> (1);
+      block->append (increment);
+    }
+
+  result = block->append (create<jit_call> (jit_typeinfo::make_range, base,
+                                            limit, increment));
+                                            
 }
 
 void
--- a/src/pt-jit.h	Mon Jun 04 13:10:44 2012 -0500
+++ b/src/pt-jit.h	Mon Jun 04 17:18:47 2012 -0500
@@ -95,8 +95,6 @@
 struct
 jit_range
 {
-  jit_range (void) {}
-
   jit_range (const Range& from) : base (from.base ()), limit (from.limit ()),
     inc (from.inc ()), nelem (from.nelem ())
     {}
@@ -182,6 +180,16 @@
       arguments[1] = arg1;
     }
 
+    overload (llvm::Function *f, bool e, bool s, jit_type *r, jit_type *arg0,
+              jit_type *arg1, jit_type *arg2) : function (f), can_error (e),
+                                                side_effects (s), result (r),
+                                                arguments (3)
+    {
+      arguments[0] = arg0;
+      arguments[1] = arg1;
+      arguments[2] = arg2;
+    }
+
     llvm::Function *function;
     bool can_error;
     bool side_effects;
@@ -207,6 +215,13 @@
     add_overload (ol);
   }
 
+  void add_overload (llvm::Function *f, bool e, bool s, jit_type *r, jit_type *arg0,
+                     jit_type *arg1, jit_type *arg2)
+  {
+    overload ol (f, e, s, r, arg0, arg1, arg2);
+    add_overload (ol);
+  }
+
   void add_overload (const overload& func,
                      const std::vector<jit_type*>& args);
 
@@ -320,6 +335,11 @@
     return instance->for_index_fn;
   }
 
+  static const jit_function& make_range (void)
+  {
+    return instance->make_range_fn;
+  }
+
   static const jit_function& cast (jit_type *result)
   {
     return instance->do_cast (result);
@@ -484,6 +504,7 @@
   jit_function for_check_fn;
   jit_function for_index_fn;
   jit_function logically_true;
+  jit_function make_range_fn;
 
   // type id -> cast function TO that type
   std::vector<jit_function> casts;
@@ -776,6 +797,17 @@
     stash_argument (2, arg2);
   }
 
+  jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2,
+                   jit_value *arg3)
+    : already_infered (3, reinterpret_cast<jit_type *>(0)), arguments (4), 
+      id (next_id ()), mparent (0)
+  {
+    stash_argument (0, arg0);
+    stash_argument (1, arg1);
+    stash_argument (2, arg2);
+    stash_argument (3, arg3);
+  }
+
   static void reset_ids (void)
   {
     next_id (true);
@@ -1441,6 +1473,15 @@
             jit_value *arg0, jit_value *arg1) : jit_instruction (arg0, arg1),
                                                 mfunction (afunction ()) {}
 
+  jit_call (const jit_function& (*afunction) (void),
+            jit_value *arg0, jit_value *arg1, jit_value *arg2)
+    : jit_instruction (arg0, arg1, arg2), mfunction (afunction ()) {}
+
+  jit_call (const jit_function& (*afunction) (void),
+            jit_value *arg0, jit_value *arg1, jit_value *arg2, jit_value *arg3)
+    : jit_instruction (arg0, arg1, arg2, arg3), mfunction (afunction ()) {}
+                                                
+
   const jit_function& function (void) const { return mfunction; }
 
   bool has_side_effects (void) const
@@ -1722,6 +1763,16 @@
     track_value (ret);
     return ret;
   }
+
+  template <typename T, typename ARG0, typename ARG1, typename ARG2,
+            typename ARG3>
+  T *create (const ARG0& arg0, const ARG1& arg1, const ARG2& arg2,
+             const ARG3& arg3)
+  {
+    T *ret = new T(arg0, arg1, arg2, arg3);
+    track_value (ret);
+    return ret;
+  }
 private:
   typedef std::list<jit_block *> block_list;
   typedef block_list::iterator block_iterator;