changeset 14974:e3cd4c9d7ccc

Generalize builtin specification in JIT and add support for cos and exp * src/ov-builtin.cc (octave_builtin::function): New function. * src/ov-builtin.h (octave_builtin::function): New declaration. * src/pt-jit.cc (gripe_bad_result, octave_jit_call, jit_typeinfo::add_builtin, jit_typeinfo::register_intrinsic, jit_typeinfo::find_builtin, jit_typeinfo::register_generic): New function. (jit_typeinfo::jit_typeinfo): Generalize builtin specification and add support for cos and exp. (jit_typeinfo::create_function): New overload. * src/pt-jit.h (overload::overload, jit_function::add_overload, jit_typeinfo::create_function): New overload. (jit_typeinfo::add_builtin, jit_typeinfo::register_intrinsic, jit_typeinfo::register_generic, jit_typeinfo::find_builtin): New declaration.
author Max Brister <max@2bass.com>
date Wed, 27 Jun 2012 14:14:20 -0500
parents 2960f1b2d6ea
children 95bfd032f4c7
files src/ov-builtin.cc src/ov-builtin.h src/pt-jit.cc src/pt-jit.h
diffstat 4 files changed, 199 insertions(+), 29 deletions(-) [+]
line wrap: on
line diff
--- a/src/ov-builtin.cc	Tue Jun 26 16:15:30 2012 -0500
+++ b/src/ov-builtin.cc	Wed Jun 27 14:14:20 2012 -0500
@@ -164,4 +164,10 @@
   jtype = &type;
 }
 
+octave_builtin::fcn
+octave_builtin::function (void) const
+{
+  return f;
+}
+
 const std::list<octave_lvalue> *octave_builtin::curr_lvalue_list = 0;
--- a/src/ov-builtin.h	Tue Jun 26 16:15:30 2012 -0500
+++ b/src/ov-builtin.h	Wed Jun 27 14:14:20 2012 -0500
@@ -80,6 +80,8 @@
 
   void stash_jit (jit_type& type);
 
+  fcn function (void) const;
+
   static const std::list<octave_lvalue> *curr_lvalue_list;
 
 protected:
--- a/src/pt-jit.cc	Tue Jun 26 16:15:30 2012 -0500
+++ b/src/pt-jit.cc	Wed Jun 27 14:14:20 2012 -0500
@@ -261,6 +261,55 @@
   std::cout << *m << std::endl;
 }
 
+static void
+gripe_bad_result (void)
+{
+  error ("incorrect type information given to the JIT compiler");
+}
+
+// FIXME: Add support for multiple outputs
+extern "C" octave_base_value *
+octave_jit_call (octave_builtin::fcn fn, size_t nargin,
+                 octave_base_value **argin, jit_type *result_type)
+{
+  octave_value_list ovl (nargin);
+  for (size_t i = 0; i < nargin; ++i)
+    ovl.xelem (i) = octave_value (argin[i]);
+
+  ovl = fn (ovl, 1);
+
+  // These type checks are not strictly required, but I'm guessing that
+  // incorrect types will be entered on occasion. This will be very difficult to
+  // debug unless we do the sanity check here.
+  if (result_type)
+    {
+      if (ovl.length () != 1)
+        {
+          gripe_bad_result ();
+          return 0;
+        }
+
+      octave_value& result = ovl.xelem (0);
+      jit_type *jtype = jit_typeinfo::join (jit_typeinfo::type_of (result),
+                                            result_type);
+      if (jtype != result_type)
+        {
+          gripe_bad_result ();
+          return 0;
+        }
+
+      octave_base_value *ret = result.internal_rep ();
+      ret->grab ();
+      return ret;
+    }
+
+  if (! (ovl.length () == 0
+         || (ovl.length () == 1 && ovl.xelem (0).is_undefined ())))
+    gripe_bad_result ();
+
+  return 0;
+}
+
 // -------------------- jit_range --------------------
 std::ostream&
 operator<< (std::ostream& os, const jit_range& rng)
@@ -408,8 +457,6 @@
   boolean = new_type ("bool", any, bool_t);
   index = new_type ("index", any, index_t);
 
-  sin_type = new_type ("sin", any, any_t);
-
   casts.resize (next_id + 1);
   identities.resize (next_id + 1, 0);
 
@@ -900,33 +947,27 @@
 
 
   // -------------------- builtin functions --------------------
-
-  // FIXME: Handling this here is messy, but it's the easiest way for now
-  // FIXME: Come up with a nicer way of defining functions
-  octave_value ov_sin = symbol_table::builtin_find ("sin");
-  octave_builtin *bsin
-    = dynamic_cast<octave_builtin *> (ov_sin.internal_rep ());
-  if (bsin)
+  add_builtin ("sin");
+  register_intrinsic ("sin", llvm::Intrinsic::sin, scalar, scalar);
+  register_generic ("sin", matrix, matrix);
+
+  add_builtin ("cos");
+  register_intrinsic ("cos", llvm::Intrinsic::cos, scalar, scalar);
+  register_generic ("cos", matrix, matrix);
+
+  add_builtin ("exp");
+  register_intrinsic ("exp", llvm::Intrinsic::cos, scalar, scalar);
+  register_generic ("exp", matrix, matrix);
+
+  casts.resize (next_id + 1);
+  fn = create_identity (any);
+  for (std::map<std::string, jit_type *>::iterator iter = builtins.begin ();
+       iter != builtins.end (); ++iter)
     {
-      bsin->stash_jit (*sin_type);
-
-      llvm::Function *isin
-        = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::sin,
-                                           llvm::makeArrayRef (scalar_t));
-      fn = create_function ("octave_jit_sin", scalar, any, scalar);
-      body = llvm::BasicBlock::Create (context, "body", fn);
-      builder.SetInsertPoint (body);
-      {
-        llvm::Value *ret = builder.CreateCall (isin, ++fn->arg_begin ());
-        builder.CreateRet (ret);
-      }
-      llvm::verifyFunction (*fn);
-      paren_subsref_fn.add_overload (fn, false, scalar, sin_type, scalar);
-      release_fn.add_overload (release_any, false, 0, sin_type);
-
-      fn = create_identity (any);
-      casts[any->type_id ()].add_overload (fn, false, any, sin_type);
-      casts[sin_type->type_id ()].add_overload (fn, false, sin_type, any);
+      jit_type *btype = iter->second;
+      release_fn.add_overload (release_any, false, 0, btype);
+      casts[any->type_id ()].add_overload (fn, false, any, btype);
+      casts[btype->type_id ()].add_overload (fn, false, btype, any);
     }
 }
 
@@ -1014,6 +1055,19 @@
 }
 
 llvm::Function *
+jit_typeinfo::create_function (const llvm::Twine& name, jit_type *ret,
+                               const std::vector<jit_type *>& args)
+{
+  llvm::Type *void_t = llvm::Type::getVoidTy (context);
+  std::vector<llvm::Type *> llvm_args (args.size (), void_t);
+  for (size_t i = 0; i < args.size (); ++i)
+    if (args[i])
+      llvm_args[i] = args[i]->to_llvm ();
+
+  return create_function (name, ret ? ret->to_llvm () : void_t, llvm_args);
+}
+
+llvm::Function *
 jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret,
                                const std::vector<llvm::Type *>& args)
 {
@@ -1051,6 +1105,74 @@
   return builder.CreateLoad (lerror_state);
 }
 
+void
+jit_typeinfo::add_builtin (const std::string& name)
+{
+  jit_type *btype = new_type (name, any, any->to_llvm ());
+  builtins[name] = btype;
+
+  octave_builtin *ov_builtin = find_builtin (name);
+  if (ov_builtin)
+    ov_builtin->stash_jit (*btype);
+}
+
+void
+jit_typeinfo::register_intrinsic (const std::string& name, size_t iid,
+                                  jit_type *result,
+                                  const std::vector<jit_type *>& args)
+{
+  jit_type *builtin_type = builtins[name];
+  size_t nargs = args.size ();
+  llvm::SmallVector<llvm::Type *, 5> llvm_args (nargs);
+  for (size_t i = 0; i < nargs; ++i)
+    llvm_args[i] = args[i]->to_llvm ();
+
+  llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID> (iid);
+  llvm::Function *ifun = llvm::Intrinsic::getDeclaration (module, id,
+                                                          llvm_args);
+  std::stringstream fn_name;
+  fn_name << "octave_jit_" << name;
+
+  std::vector<jit_type *> args1 (nargs + 1);
+  args1[0] = builtin_type;
+  std::copy (args.begin (), args.end (), args1.begin () + 1);
+
+  // The first argument will be the Octave function, but we already know that
+  // the function call is the equivalent of the intrinsic, so we ignore it and
+  // call the intrinsic with the remaining arguments.
+  llvm::Function *fn = create_function (fn_name.str (), result, args1);
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+
+  llvm::SmallVector<llvm::Value *, 5> fargs (nargs);
+  llvm::Function::arg_iterator iter = fn->arg_begin ();
+  ++iter;
+  for (size_t i = 0; i < nargs; ++i, ++iter)
+    fargs[i] = iter;
+
+  llvm::Value *ret = builder.CreateCall (ifun, fargs);
+  builder.CreateRet (ret);
+  llvm::verifyFunction (*fn);
+
+  paren_subsref_fn.add_overload (fn, false, result, args1);
+}
+
+octave_builtin *
+jit_typeinfo::find_builtin (const std::string& name)
+{
+  // FIXME: Finalize what we want to store in octave_builtin, then add functions
+  // to access these values in octave_value
+  octave_value ov_builtin = symbol_table::find (name);
+  return dynamic_cast<octave_builtin *> (ov_builtin.internal_rep ());
+}
+
+void
+jit_typeinfo::register_generic (const std::string&, jit_type *,
+                                const std::vector<jit_type *>&)
+{
+  // FIXME: Implement
+}
+
 jit_type *
 jit_typeinfo::do_type_of (const octave_value &ov) const
 {
--- a/src/pt-jit.h	Tue Jun 26 16:15:30 2012 -0500
+++ b/src/pt-jit.h	Wed Jun 27 14:14:20 2012 -0500
@@ -84,6 +84,7 @@
 }
 
 class octave_base_value;
+class octave_builtin;
 class octave_value;
 class tree;
 class tree_expression;
@@ -329,6 +330,11 @@
       arguments[2] = arg2;
     }
 
+    overload (llvm::Function *f, bool e, jit_type *r,
+              const std::vector<jit_type *>& aarguments)
+      : function (f), can_error (e), result (r), arguments (aarguments)
+    {}
+
     llvm::Function *function;
     bool can_error;
     jit_type *result;
@@ -360,6 +366,13 @@
     add_overload (ol);
   }
 
+  void add_overload (llvm::Function *f, bool e, jit_type *r,
+                     const std::vector<jit_type *>& args)
+  {
+    overload ol (f, e, r, args);
+    add_overload (ol);
+  }
+
   void add_overload (const overload& func,
                      const std::vector<jit_type*>& args);
 
@@ -660,6 +673,9 @@
     return create_function (name, ret, args);
   }
 
+  llvm::Function *create_function (const llvm::Twine& name, jit_type *ret,
+                                   const std::vector<jit_type *>& args);
+
   llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret,
                                    const std::vector<llvm::Type *>& args);
 
@@ -667,6 +683,30 @@
 
   llvm::Value *do_insert_error_check (void);
 
+  void add_builtin (const std::string& name);
+
+  void register_intrinsic (const std::string& name, size_t id,
+                           jit_type *result, jit_type *arg0)
+  {
+    std::vector<jit_type *> args (1, arg0);
+    register_intrinsic (name, id, result, args);
+  }
+
+  void register_intrinsic (const std::string& name, size_t id, jit_type *result,
+                           const std::vector<jit_type *>& args);
+
+  void register_generic (const std::string& name, jit_type *result,
+                         jit_type *arg0)
+  {
+    std::vector<jit_type *> args (1, arg0);
+    register_generic (name, result, args);
+  }
+
+  void register_generic (const std::string& name, jit_type *result,
+                         const std::vector<jit_type *>& args);
+
+  octave_builtin *find_builtin (const std::string& name);
+
   static jit_typeinfo *instance;
 
   llvm::Module *module;
@@ -683,7 +723,7 @@
   jit_type *string;
   jit_type *boolean;
   jit_type *index;
-  jit_type *sin_type;
+  std::map<std::string, jit_type *> builtins;
 
   std::vector<jit_function> binary_ops;
   jit_function grab_fn;