# HG changeset patch # User Max Brister # Date 1337896090 21600 # Node ID f0499b0af64605ede75609f4cd4be264f4379cd2 # Parent 13465aab507f4773b7eb3a0f6a2fc7721d97d5eb# Parent 22244a235fd09f9d17898508666bd2491a6aecf8 maint: Periodic merge of default to jit diff -r 22244a235fd0 -r f0499b0af646 build-aux/common.mk --- a/build-aux/common.mk Thu May 24 15:38:59 2012 -0400 +++ b/build-aux/common.mk Thu May 24 15:48:10 2012 -0600 @@ -544,6 +544,9 @@ -e "s|%OCTAVE_CONF_MAGICK_CPPFLAGS%|\"${MAGICK_CPPFLAGS}\"|" \ -e "s|%OCTAVE_CONF_MAGICK_LDFLAGS%|\"${MAGICK_LDFLAGS}\"|" \ -e "s|%OCTAVE_CONF_MAGICK_LIBS%|\"${MAGICK_LIBS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_CPPFLAGS%|\"${LLVM_CPPFLAGS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_LDFLAGS%|\"${LLVM_LDFLAGS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_LIBS%|\"${LLVM_LIBS}\"|" \ -e 's|%OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS%|\"@MKOCTFILE_DL_LDFLAGS@\"|' \ -e "s|%OCTAVE_CONF_OCTAVE_LINK_DEPS%|\"${OCTAVE_LINK_DEPS}\"|" \ -e "s|%OCTAVE_CONF_OCTAVE_LINK_OPTS%|\"${OCTAVE_LINK_OPTS}\"|" \ diff -r 22244a235fd0 -r f0499b0af646 configure.ac --- a/configure.ac Thu May 24 15:38:59 2012 -0400 +++ b/configure.ac Thu May 24 15:48:10 2012 -0600 @@ -712,6 +712,59 @@ [ZLIB library not found. Octave will not be able to save or load compressed data files or HDF5 files.], [zlib.h], [gzclearerr]) +### Check for the llvm library +dnl +dnl +dnl llvm is odd and has its own pkg-config like script. We should probably check +dnl for existance and +dnl + +LLVM_CONFIG=llvm-config +LLVM_CPPFLAGS= +LLVM_LDFLAGS= +LLVM_LIBS= + +LLVM_LDFLAGS=`$LLVM_CONFIG --ldflags` +LLVM_LIBS=`$LLVM_CONFIG --libs` +LLVM_CPPFLAGS=`$LLVM_CONFIG --cxxflags` + +warn_llvm="LLVM library fails tests. JIT compilation will be disabled." + +save_CPPFLAGS="$CPPFLAGS" +save_LIBS="$LIBS" +save_LDFLAGS="$LDFLAGS" +CPPFLAGS="$LLVM_CPPFLAGS $CPPFLAGS" +LIBS="$LLVM_LIBS $LIBS" +LDFLAGS="$LLVM_LDFLAGS $LDFLAGS" +AC_LANG_PUSH(C++) + AC_CHECK_HEADER([llvm/LLVMContext.h], [ + AC_MSG_CHECKING([for llvm::getGlobalContext in llvm/LLVMContext.h]) + AC_TRY_LINK([#include ], [llvm::getGlobalContext ()], [ + AC_MSG_RESULT(yes) + warn_llvm= + ], [ + AC_MSG_RESULT(no) + ]) + ]) +AC_LANG_POP(C++) +CPPFLAGS="$save_CPPFLAGS" +LIBS="$save_LIBS" +LDFLAGS="$save_LDFLAGS" + +if test -z "$warn_llvm"; then + AC_DEFINE(HAVE_LLVM, 1, [Define if LLVM is available.]) +else + LLVM_CPPFLAGS= + LLVM_LDFLAGS= + LLVM_LIBS= + AC_MSG_WARN([$warn_llvm]) +fi +AC_DEFINE(HAVE_LLVM, 1, [Define if LLVM is available.]) + +AC_SUBST(LLVM_CPPFLAGS) +AC_SUBST(LLVM_LDFLAGS) +AC_SUBST(LLVM_LIBS) + ### Check for HDF5 library. save_CPPFLAGS="$CPPFLAGS" @@ -2242,6 +2295,9 @@ Magick++ CPPFLAGS: $MAGICK_CPPFLAGS Magick++ LDFLAGS: $MAGICK_LDFLAGS Magick++ libraries: $MAGICK_LIBS + LLVM CPPFLAGS: $LLVM_CPPFLAGS + LLVM LDFLAGS: $LLVM_LDFLAGS + LLVM Libraries: $LLVM_LIBS HDF5 CPPFLAGS: $HDF5_CPPFLAGS HDF5 LDFLAGS: $HDF5_LDFLAGS HDF5 libraries: $HDF5_LIBS diff -r 22244a235fd0 -r f0499b0af646 src/Makefile.am --- a/src/Makefile.am Thu May 24 15:38:59 2012 -0400 +++ b/src/Makefile.am Thu May 24 15:48:10 2012 -0600 @@ -220,6 +220,7 @@ pt-fcn-handle.h \ pt-id.h \ pt-idx.h \ + pt-jit.h \ pt-jump.h \ pt-loop.h \ pt-mat.h \ @@ -392,6 +393,7 @@ pt-fcn-handle.cc \ pt-id.cc \ pt-idx.cc \ + pt-jit.cc \ pt-jump.cc \ pt-loop.cc \ pt-mat.cc \ diff -r 22244a235fd0 -r f0499b0af646 src/TEMPLATE-INST/Array-jit.cc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/TEMPLATE-INST/Array-jit.cc Thu May 24 15:48:10 2012 -0600 @@ -0,0 +1,34 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "Array.h" +#include "Array.cc" + +#include "pt-jit.h" + +NO_INSTANTIATE_ARRAY_SORT (jit_function::overload); + +INSTANTIATE_ARRAY (jit_function::overload, OCTINTERP_API); diff -r 22244a235fd0 -r f0499b0af646 src/TEMPLATE-INST/module.mk --- a/src/TEMPLATE-INST/module.mk Thu May 24 15:38:59 2012 -0400 +++ b/src/TEMPLATE-INST/module.mk Thu May 24 15:48:10 2012 -0600 @@ -2,4 +2,5 @@ TEMPLATE_INST_SRC = \ TEMPLATE-INST/Array-os.cc \ - TEMPLATE-INST/Array-tc.cc + TEMPLATE-INST/Array-tc.cc \ + TEMPLATE-INST/Array-jit.cc diff -r 22244a235fd0 -r f0499b0af646 src/link-deps.mk --- a/src/link-deps.mk Thu May 24 15:38:59 2012 -0400 +++ b/src/link-deps.mk Thu May 24 15:48:10 2012 -0600 @@ -13,14 +13,16 @@ $(Z_LIBS) \ $(OPENGL_LIBS) \ $(X11_LIBS) \ - $(CARBON_LIBS) + $(CARBON_LIBS) \ + $(LLVM_LIBS) LIBOCTINTERP_LINK_OPTS = \ $(GRAPHICS_LDFLAGS) \ $(FT2_LDFLAGS) \ $(HDF5_LDFLAGS) \ $(Z_LDFLAGS) \ - $(REGEX_LDFLAGS) + $(REGEX_LDFLAGS) \ + $(LLVM_LDFLAGS) OCT_LINK_DEPS = diff -r 22244a235fd0 -r f0499b0af646 src/oct-conf.in.h --- a/src/oct-conf.in.h Thu May 24 15:38:59 2012 -0400 +++ b/src/oct-conf.in.h Thu May 24 15:48:10 2012 -0600 @@ -384,6 +384,18 @@ #define OCTAVE_CONF_MAGICK_LIBS %OCTAVE_CONF_MAGICK_LIBS% #endif +#ifndef OCTAVE_CONF_LLVM_CPPFLAGS +#define OCTAVE_CONF_LLVM_CPPFLAGS %OCTAVE_CONF_LLVM_CPPFLAGS% +#endif + +#ifndef OCTAVE_CONF_LLVM_LDFLAGS +#define OCTAVE_CONF_LLVM_LDFLAGS %OCTAVE_CONF_LLVM_LDFLAGS% +#endif + +#ifndef OCTAVE_CONF_LLVM_LIBS +#define OCTAVE_CONF_LLVM_LIBS %OCTAVE_CONF_LLVM_LIBS% +#endif + #ifndef OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS #define OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS %OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS% #endif diff -r 22244a235fd0 -r f0499b0af646 src/ov-base.h --- a/src/ov-base.h Thu May 24 15:38:59 2012 -0400 +++ b/src/ov-base.h Thu May 24 15:48:10 2012 -0600 @@ -755,6 +755,21 @@ virtual bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + // Grab the reference count. For use by jit. + void + grab (void) + { + ++count; + } + + // Release the reference count. For use by jit. + void + release (void) + { + if (--count == 0) + delete this; + } + protected: // This should only be called for derived types. diff -r 22244a235fd0 -r f0499b0af646 src/pt-eval.cc --- a/src/pt-eval.cc Thu May 24 15:38:59 2012 -0400 +++ b/src/pt-eval.cc Thu May 24 15:48:10 2012 -0600 @@ -44,6 +44,10 @@ #include "symtab.h" #include "unwind-prot.h" +//FIXME: This should be part of tree_evaluator +#include "pt-jit.h" +static tree_jit jiter; + static tree_evaluator std_evaluator; tree_evaluator *current_evaluator = &std_evaluator; @@ -680,6 +684,9 @@ tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); + if (jiter.execute (stmt)) + return; + if (cmd || expr) { if (statement_context == function || statement_context == script) diff -r 22244a235fd0 -r f0499b0af646 src/pt-id.cc --- a/src/pt-id.cc Thu May 24 15:38:59 2012 -0400 +++ b/src/pt-id.cc Thu May 24 15:48:10 2012 -0600 @@ -63,7 +63,7 @@ if (error_state) return retval; - octave_value val = xsym ().find (); + octave_value val = sym->find (); if (val.is_defined ()) { @@ -114,7 +114,7 @@ octave_lvalue tree_identifier::lvalue (void) { - return octave_lvalue (&(xsym().varref ())); + return octave_lvalue (&(sym->varref ())); } tree_identifier * diff -r 22244a235fd0 -r f0499b0af646 src/pt-id.h --- a/src/pt-id.h Thu May 24 15:38:59 2012 -0400 +++ b/src/pt-id.h Thu May 24 15:48:10 2012 -0600 @@ -46,12 +46,12 @@ public: tree_identifier (int l = -1, int c = -1) - : tree_expression (l, c), sym (), scope (-1) { } + : tree_expression (l, c) { } tree_identifier (const symbol_table::symbol_record& s, int l = -1, int c = -1, symbol_table::scope_id sc = symbol_table::current_scope ()) - : tree_expression (l, c), sym (s), scope (sc) { } + : tree_expression (l, c), sym (s, sc) { } ~tree_identifier (void) { } @@ -63,9 +63,9 @@ // accessing it through sym so that this function may remain const. std::string name (void) const { return sym.name (); } - bool is_defined (void) { return xsym().is_defined (); } + bool is_defined (void) { return sym->is_defined (); } - virtual bool is_variable (void) { return xsym().is_variable (); } + virtual bool is_variable (void) { return sym->is_variable (); } virtual bool is_black_hole (void) { return false; } @@ -87,14 +87,14 @@ octave_value do_lookup (const octave_value_list& args = octave_value_list ()) { - return xsym().find (args); + return sym->find (args); } - void mark_global (void) { xsym().mark_global (); } + void mark_global (void) { sym->mark_global (); } - void mark_as_static (void) { xsym().init_persistent (); } + void mark_as_static (void) { sym->init_persistent (); } - void mark_as_formal_parameter (void) { xsym().mark_formal (); } + void mark_as_formal_parameter (void) { sym->mark_formal (); } // We really need to know whether this symbol referst to a variable // or a function, but we may not know that yet. @@ -114,28 +114,14 @@ void accept (tree_walker& tw); + symbol_table::symbol_record_ref symbol (void) const + { + return sym; + } private: // The symbol record that this identifier references. - symbol_table::symbol_record sym; - - symbol_table::scope_id scope; - - // A script may be executed in multiple scopes. If the last one was - // different from the one we are in now, update sym to be from the - // new scope. - symbol_table::symbol_record& xsym (void) - { - symbol_table::scope_id curr_scope = symbol_table::current_scope (); - - if (scope != curr_scope || ! sym.is_valid ()) - { - scope = curr_scope; - sym = symbol_table::insert (sym.name ()); - } - - return sym; - } + symbol_table::symbol_record_ref sym; // No copying! diff -r 22244a235fd0 -r f0499b0af646 src/pt-jit.cc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/pt-jit.cc Thu May 24 15:48:10 2012 -0600 @@ -0,0 +1,1338 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "pt-jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "octave.h" +#include "ov-fcn-handle.h" +#include "ov-usr-fcn.h" +#include "ov-scalar.h" +#include "pt-all.h" + +// FIXME: Remove eventually +// For now we leave this in so people tell when JIT actually happens +static const bool debug_print = false; + +static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); + +static llvm::LLVMContext& context = llvm::getGlobalContext (); + +jit_typeinfo *jit_typeinfo::instance; + +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::exception {}; + +static void +fail (void) +{ + throw jit_fail_exception (); +} + +// function that jit code calls +extern "C" void +octave_jit_print_any (const char *name, octave_base_value *obv) +{ + obv->print_with_name (octave_stdout, name, true); +} + +extern "C" void +octave_jit_print_double (const char *name, double value) +{ + // FIXME: We should avoid allocating a new octave_scalar each time + octave_value ov (value); + ov.print_with_name (octave_stdout, name); +} + +extern "C" octave_base_value* +octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs, + octave_base_value *rhs) +{ + octave_value olhs (lhs); + octave_value orhs (rhs); + octave_value result = do_binary_op (op, olhs, orhs); + octave_base_value *rep = result.internal_rep (); + rep->grab (); + return rep; +} + +extern "C" void +octave_jit_release_any (octave_base_value *obv) +{ + obv->release (); +} + +extern "C" octave_base_value * +octave_jit_grab_any (octave_base_value *obv) +{ + obv->grab (); + return obv; +} + +extern "C" double +octave_jit_cast_scalar_any (octave_base_value *obv) +{ + double ret = obv->double_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_scalar (double value) +{ + return new octave_scalar (value); +} + +// -------------------- jit_type -------------------- +llvm::Type * +jit_type::to_llvm_arg (void) const +{ + return llvm_type ? llvm_type->getPointerTo () : 0; +} + +// -------------------- jit_function -------------------- +void +jit_function::add_overload (const overload& func, + const std::vector& args) +{ + if (args.size () >= overloads.size ()) + overloads.resize (args.size () + 1); + + Array& over = overloads[args.size ()]; + dim_vector dv (over.dims ()); + Array idx = to_idx (args); + bool must_resize = false; + + if (dv.length () != idx.numel ()) + { + dv.resize (idx.numel ()); + must_resize = true; + } + + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (dv(i) <= idx(i)) + { + must_resize = true; + dv(i) = idx(i) + 1; + } + + if (must_resize) + over.resize (dv); + + over(idx) = func; +} + +const jit_function::overload& +jit_function::get_overload (const std::vector& types) const +{ + // FIXME: We should search for the next best overload on failure + static overload null_overload; + if (types.size () >= overloads.size ()) + return null_overload; + + for (size_t i =0; i < types.size (); ++i) + if (! types[i]) + return null_overload; + + const Array& over = overloads[types.size ()]; + dim_vector dv (over.dims ()); + Array idx = to_idx (types); + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (idx(i) >= dv(i)) + return null_overload; + + return over(idx); +} + +Array +jit_function::to_idx (const std::vector& types) const +{ + octave_idx_type numel = types.size (); + if (numel == 1) + numel = 2; + + Array idx (dim_vector (1, numel)); + for (octave_idx_type i = 0; i < static_cast (types.size ()); + ++i) + idx(i) = types[i]->type_id (); + + if (types.size () == 1) + { + idx(1) = idx(0); + idx(0) = 0; + } + + return idx; +} + +// -------------------- jit_typeinfo -------------------- +void +jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) +{ + instance = new jit_typeinfo (m, e); +} + +jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) + : module (m), engine (e), next_id (0) +{ + // FIXME: We should be registering types like in octave_value_typeinfo + ov_t = llvm::StructType::create (context, "octave_base_value"); + ov_t = ov_t->getPointerTo (); + + llvm::Type *dbl = llvm::Type::getDoubleTy (context); + llvm::Type *bool_t = llvm::Type::getInt1Ty (context); + llvm::Type *string_t = llvm::Type::getInt8Ty (context); + string_t = string_t->getPointerTo (); + llvm::Type *index_t = 0; + switch (sizeof(octave_idx_type)) + { + case 4: + index_t = llvm::Type::getInt32Ty (context); + break; + case 8: + index_t = llvm::Type::getInt64Ty (context); + break; + default: + assert (false && "Unrecognized index type size"); + } + + llvm::StructType *range_t = llvm::StructType::create (context, "range"); + std::vector range_contents (4, dbl); + range_contents[3] = index_t; + range_t->setBody (range_contents); + + // create types + any = new_type ("any", 0, ov_t); + scalar = new_type ("scalar", any, dbl); + range = new_type ("range", any, range_t); + string = new_type ("string", any, string_t); + boolean = new_type ("bool", any, bool_t); + index = new_type ("index", any, index_t); + + casts.resize (next_id + 1); + identities.resize (next_id + 1, 0); + + // any with anything is an any op + llvm::Function *fn; + llvm::Type *binary_op_type + = llvm::Type::getIntNTy (context, sizeof (octave_value::binary_op)); + llvm::Function *any_binary = create_function ("octave_jit_binary_any_any", + any->to_llvm (), binary_op_type, + any->to_llvm (), any->to_llvm ()); + engine->addGlobalMapping (any_binary, + reinterpret_cast(&octave_jit_binary_any_any)); + + binary_ops.resize (octave_value::num_binary_ops); + for (size_t i = 0; i < octave_value::num_binary_ops; ++i) + { + octave_value::binary_op op = static_cast (i); + std::string op_name = octave_value::binary_op_as_string (op); + binary_ops[i].stash_name ("binary" + op_name); + } + + for (int op = 0; op < octave_value::num_binary_ops; ++op) + { + llvm::Twine fn_name ("octave_jit_binary_any_any_"); + fn_name = fn_name + llvm::Twine (op); + fn = create_function (fn_name, any, any, any); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::APInt op_int(sizeof (octave_value::binary_op), op, + std::numeric_limits::is_signed); + llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int); + llvm::Value *ret = builder.CreateCall3 (any_binary, + op_as_llvm, + fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + + jit_function::overload overload (fn, true, any, any, any); + for (octave_idx_type i = 0; i < next_id; ++i) + binary_ops[op].add_overload (overload); + } + + llvm::Type *void_t = llvm::Type::getVoidTy (context); + + // grab any + fn = create_function ("octave_jit_grab_any", any, any); + + engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_grab_any)); + grab_fn.add_overload (fn, false, any, any); + grab_fn.stash_name ("grab"); + + // grab scalar + fn = create_identity (scalar); + grab_fn.add_overload (fn, false, scalar, scalar); + + // release any + fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); + engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_release_any)); + release_fn.add_overload (fn, false, 0, any); + release_fn.stash_name ("release"); + + // release scalar + fn = create_identity (scalar); + release_fn.add_overload (fn, false, 0, scalar); + + // now for binary scalar operations + // FIXME: Finish all operations + add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); + add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul); + add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul); + + // FIXME: Warn if rhs is zero + add_binary_op (scalar, octave_value::op_div, llvm::Instruction::FDiv); + add_binary_op (scalar, octave_value::op_el_div, llvm::Instruction::FDiv); + + add_binary_fcmp (scalar, octave_value::op_lt, llvm::CmpInst::FCMP_ULT); + add_binary_fcmp (scalar, octave_value::op_le, llvm::CmpInst::FCMP_ULE); + add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); + add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); + add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); + add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + + // now for printing functions + print_fn.stash_name ("print"); + add_print (any, reinterpret_cast (&octave_jit_print_any)); + add_print (scalar, reinterpret_cast (&octave_jit_print_double)); + + // bounds check for for loop + fn = create_function ("octave_jit_simple_for_range", boolean, range, index); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *nelem + = builder.CreateExtractValue (fn->arg_begin (), 3); + // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ()); + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *ret = builder.CreateICmpULT (idx, nelem); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_check.add_overload (fn, false, boolean, range, index); + + // increment for for loop + fn = create_function ("octave_jit_imple_for_range_incr", index, index); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + llvm::Value *idx = fn->arg_begin (); + llvm::Value *ret = builder.CreateAdd (idx, one); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_incr.add_overload (fn, false, index, index); + + // index variabe for for loop + fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *didx = builder.CreateUIToFP (idx, dbl); + llvm::Value *rng = fn->arg_begin (); + llvm::Value *base = builder.CreateExtractValue (rng, 0); + llvm::Value *inc = builder.CreateExtractValue (rng, 2); + + llvm::Value *ret = builder.CreateFMul (didx, inc); + ret = builder.CreateFAdd (base, ret); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_index.add_overload (fn, false, scalar, range, index); + + // logically true + // FIXME: Check for NaN + fn = create_function ("octave_logically_true_scalar", boolean, scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0); + llvm::Value *ret = builder.CreateFCmpUNE (fn->arg_begin (), zero); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + logically_true.add_overload (fn, true, boolean, scalar); + + fn = create_function ("octave_logically_true_bool", boolean, boolean); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + logically_true.add_overload (fn, false, boolean, boolean); + logically_true.stash_name ("logically_true"); + + casts[any->type_id ()].stash_name ("(any)"); + casts[scalar->type_id ()].stash_name ("(scalar)"); + + // cast any <- scalar + fn = create_function ("octave_jit_cast_any_scalar", any, scalar); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_scalar)); + casts[any->type_id ()].add_overload (fn, false, any, scalar); + + // cast scalar <- any + fn = create_function ("octave_jit_cast_scalar_any", scalar, any); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_scalar_any)); + casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + + // cast any <- any + fn = create_identity (any); + casts[any->type_id ()].add_overload (fn, false, any, any); + + // cast scalar <- scalar + fn = create_identity (scalar); + casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); +} + +void +jit_typeinfo::add_print (jit_type *ty, void *call) +{ + std::stringstream name; + name << "octave_jit_print_" << ty->name (); + + llvm::Type *void_t = llvm::Type::getVoidTy (context); + llvm::Function *fn = create_function (name.str (), void_t, + llvm::Type::getInt8PtrTy (context), + ty->to_llvm ()); + engine->addGlobalMapping (fn, call); + + jit_function::overload ol (fn, false, 0, string, ty); + print_fn.add_overload (ol); +} + +// FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp +void +jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = create_function (fname.str (), ty, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::Instruction::BinaryOps temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateBinOp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol(fn, false, ty, ty, ty); + binary_ops[op].add_overload (ol); +} + +void +jit_typeinfo::add_binary_icmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateICmp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol (fn, false, boolean, ty, ty); + binary_ops[op].add_overload (ol); +} + +void +jit_typeinfo::add_binary_fcmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateFCmp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol (fn, false, boolean, ty, ty); + binary_ops[op].add_overload (ol); +} + +llvm::Function * +jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret, + const std::vector& args) +{ + llvm::FunctionType *ft = llvm::FunctionType::get (ret, args, false); + llvm::Function *fn = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + name, module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + return fn; +} + +llvm::Function * +jit_typeinfo::create_identity (jit_type *type) +{ + size_t id = type->type_id (); + if (id >= identities.size ()) + identities.resize (id + 1, 0); + + if (! identities[id]) + { + llvm::Function *fn = create_function ("id", type, type); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + identities[id] = fn; + } + + return identities[id]; +} + +jit_type * +jit_typeinfo::do_type_of (const octave_value &ov) const +{ + if (ov.is_undefined () || ov.is_function ()) + return 0; + + if (ov.is_double_type () && ov.is_real_scalar ()) + return get_scalar (); + + if (ov.is_range ()) + return get_range (); + + return get_any (); +} + +jit_type* +jit_typeinfo::new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type) +{ + jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); + id_to_type.push_back (ret); + return ret; +} + +// -------------------- jit_block -------------------- +llvm::BasicBlock * +jit_block::to_llvm (void) const +{ + return llvm::cast (llvm_value); +} + +// -------------------- jit_call -------------------- +bool +jit_call::infer (void) +{ + // FIXME explain algorithm + jit_type *current = type (); + for (size_t i = 0; i < argument_count (); ++i) + { + jit_type *arg_type = argument_type (i); + jit_type *todo = jit_typeinfo::difference (arg_type, already_infered[i]); + if (todo) + { + already_infered[i] = todo; + jit_type *fresult = mfunction.get_result (already_infered); + current = jit_typeinfo::tunion (current, fresult); + already_infered[i] = arg_type; + } + } + + if (current != type ()) + { + stash_type (current); + return true; + } + + return false; +} + +// -------------------- jit_convert -------------------- +jit_convert::jit_convert (llvm::Module *module, tree &tee) +{ + jit_instruction::reset_ids (); + + entry_block = new jit_block ("entry"); + blocks.push_back (entry_block); + block = new jit_block ("body"); + blocks.push_back (block); + + final_block = new jit_block ("final"); + visit (tee); + blocks.push_back (final_block); + + entry_block->append (new jit_break (block)); + block->append (new jit_break (final_block)); + + for (variable_map::iterator iter = variables.begin (); + iter != variables.end (); ++iter) + final_block->append (new jit_store_argument (iter->first, iter->second)); + + // FIXME: Maybe we should remove dead code here? + + // initialize the worklist to instructions derived from constants + for (std::list::iterator iter = constants.begin (); + iter != constants.end (); ++iter) + append_users (*iter); + + // FIXME: Describe algorithm here + while (worklist.size ()) + { + jit_instruction *next = worklist.front (); + worklist.pop_front (); + + if (next->infer ()) + append_users (next); + } + + if (debug_print) + { + std::cout << "-------------------- Compiling tree --------------------\n"; + std::cout << tee.str_print_code () << std::endl; + std::cout << "-------------------- octave jit ir --------------------\n"; + for (std::list::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + (*iter)->print (std::cout, 0); + std::cout << std::endl; + } + + convert_llvm to_llvm; + function = to_llvm.convert (module, arguments, blocks, constants); + + if (debug_print) + { + std::cout << "-------------------- llvm ir --------------------"; + llvm::raw_os_ostream llvm_cout (std::cout); + function->print (llvm_cout); + std::cout << std::endl; + } +} + +void +jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) +{ + fail (); +} + +void +jit_convert::visit_argument_list (tree_argument_list&) +{ + fail (); +} + +void +jit_convert::visit_binary_expression (tree_binary_expression& be) +{ + if (be.op_type () >= octave_value::num_binary_ops) + // this is the case for bool_or and bool_and + fail (); + + tree_expression *lhs = be.lhs (); + jit_value *lhsv = visit (lhs); + + tree_expression *rhs = be.rhs (); + jit_value *rhsv = visit (rhs); + + const jit_function& fn = jit_typeinfo::binary_op (be.op_type ()); + result = block->append (new jit_call (fn, lhsv, rhsv)); +} + +void +jit_convert::visit_break_command (tree_break_command&) +{ + fail (); +} + +void +jit_convert::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +jit_convert::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +jit_convert::visit_global_command (tree_global_command&) +{ + fail (); +} + +void +jit_convert::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +jit_convert::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} + +void +jit_convert::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +jit_convert::visit_simple_for_command (tree_simple_for_command&) +{ + fail (); +} + +void +jit_convert::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +jit_convert::visit_identifier (tree_identifier& ti) +{ + std::string name = ti.name (); + variable_map::iterator iter = variables.find (name); + jit_value *var; + if (iter == variables.end ()) + { + octave_value var_value = ti.do_lookup (); + jit_type *var_type = jit_typeinfo::type_of (var_value); + var = entry_block->append (new jit_extract_argument (var_type, name)); + constants.push_back (var); + bounds.push_back (std::make_pair (var_type, name)); + variables[name] = var; + arguments.push_back (std::make_pair (name, true)); + } + else + var = iter->second; + + const jit_function& fn = jit_typeinfo::grab (); + result = block->append (new jit_call (fn, var)); +} + +void +jit_convert::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +jit_convert::visit_if_command (tree_if_command&) +{ + fail (); +} + +void +jit_convert::visit_if_command_list (tree_if_command_list&) +{ + fail (); +} + +void +jit_convert::visit_index_expression (tree_index_expression&) +{ + fail (); +} + +void +jit_convert::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +jit_convert::visit_cell (tree_cell&) +{ + fail (); +} + +void +jit_convert::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +jit_convert::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +jit_convert::visit_constant (tree_constant& tc) +{ + octave_value v = tc.rvalue1 (); + if (v.is_real_scalar () && v.is_double_type ()) + { + double dv = v.double_value (); + result = get_scalar (dv); + } + else if (v.is_range ()) + fail (); + else + fail (); +} + +void +jit_convert::visit_fcn_handle (tree_fcn_handle&) +{ + fail (); +} + +void +jit_convert::visit_parameter_list (tree_parameter_list&) +{ + fail (); +} + +void +jit_convert::visit_postfix_expression (tree_postfix_expression&) +{ + fail (); +} + +void +jit_convert::visit_prefix_expression (tree_prefix_expression&) +{ + fail (); +} + +void +jit_convert::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +jit_convert::visit_return_list (tree_return_list&) +{ + fail (); +} + +void +jit_convert::visit_simple_assignment (tree_simple_assignment& tsa) +{ + // resolve rhs + tree_expression *rhs = tsa.right_hand_side (); + jit_value *rhsv = visit (rhs); + + // resolve lhs + tree_expression *lhs = tsa.left_hand_side (); + if (! lhs->is_identifier ()) + fail (); + + std::string lhs_name = lhs->name (); + do_assign (lhs_name, rhsv, tsa.print_result ()); + result = rhsv; + + if (jit_instruction *instr = dynamic_cast(rhsv)) + instr->stash_tag (lhs_name); +} + +void +jit_convert::visit_statement (tree_statement& stmt) +{ + tree_command *cmd = stmt.command (); + tree_expression *expr = stmt.expression (); + + if (cmd) + visit (cmd); + else + { + // stolen from tree_evaluator::visit_statement + bool do_bind_ans = false; + + if (expr->is_identifier ()) + { + tree_identifier *id = dynamic_cast (expr); + + do_bind_ans = (! id->is_variable ()); + } + else + do_bind_ans = (! expr->is_assignment_expression ()); + + jit_value *expr_result = visit (expr); + + if (do_bind_ans) + do_assign ("ans", expr_result, expr->print_result ()); + else if (expr->is_identifier () && expr->print_result ()) + { + // FIXME: ugly hack, we need to come up with a way to pass + // nargout to visit_identifier + const jit_function& fn = jit_typeinfo::print_value (); + jit_const_string *name = get_string (expr->name ()); + block->append (new jit_call (fn, name, expr_result)); + } + } +} + +void +jit_convert::visit_statement_list (tree_statement_list&) +{ + fail (); +} + +void +jit_convert::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +jit_convert::visit_switch_case_list (tree_switch_case_list&) +{ + fail (); +} + +void +jit_convert::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +jit_convert::visit_try_catch_command (tree_try_catch_command&) +{ + fail (); +} + +void +jit_convert::visit_unwind_protect_command (tree_unwind_protect_command&) +{ + fail (); +} + +void +jit_convert::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +jit_convert::visit_do_until_command (tree_do_until_command&) +{ + fail (); +} + +void +jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print) +{ + variable_map::iterator iter = variables.find (lhs); + if (iter == variables.end ()) + arguments.push_back (std::make_pair (lhs, false)); + else + { + const jit_function& fn = jit_typeinfo::release (); + block->append (new jit_call (fn, iter->second)); + } + + variables[lhs] = rhs; + + if (print) + { + const jit_function& fn = jit_typeinfo::print_value (); + jit_const_string *name = get_string (lhs); + block->append (new jit_call (fn, name, rhs)); + } +} + +jit_value * +jit_convert::visit (tree& tee) +{ + result = 0; + tee.accept (*this); + + jit_value *ret = result; + result = 0; + return ret; +} + +// -------------------- jit_convert::convert_llvm -------------------- +llvm::Function * +jit_convert::convert_llvm::convert (llvm::Module *module, + const std::vector >& args, + const std::list& blocks, + const std::list& constants) +{ + jit_type *any = jit_typeinfo::get_any (); + + // argument is an array of octave_base_value*, or octave_base_value** + llvm::Type *arg_type = any->to_llvm (); // this is octave_base_value* + arg_type = arg_type->getPointerTo (); + llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), + arg_type, false); + llvm::Function *function = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + "foobar", module); + + try + { + llvm::BasicBlock *prelude = llvm::BasicBlock::Create (context, "prelude", + function); + builder.SetInsertPoint (prelude); + + llvm::Value *arg = function->arg_begin (); + for (size_t i = 0; i < args.size (); ++i) + { + llvm::Value *loaded_arg = builder.CreateConstInBoundsGEP1_32 (arg, i); + arguments[args[i].first] = loaded_arg; + } + + // we need to generate llvm values for constants, as these don't appear in + // a block + for (std::list::const_iterator iter = constants.begin (); + iter != constants.end (); ++iter) + { + jit_value *constant = *iter; + if (! dynamic_cast (constant)) + visit (constant); + } + + std::list::const_iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *jblock = *biter; + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, jblock->name (), + function); + jblock->stash_llvm (block); + } + + jit_block *first = *blocks.begin (); + builder.CreateBr (first->to_llvm ()); + + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + visit (*biter); + + builder.CreateRetVoid (); + } catch (const jit_fail_exception&) + { + function->eraseFromParent (); + throw; + } + + llvm::verifyFunction (*function); + + return function; +} + +void +jit_convert::convert_llvm::visit_const_string (jit_const_string& cs) +{ + cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); +} + +void +jit_convert::convert_llvm::visit_const_scalar (jit_const_scalar& cs) +{ + llvm::Type *dbl = llvm::Type::getDoubleTy (context); + cs.stash_llvm (llvm::ConstantFP::get (dbl, cs.value ())); +} + +void +jit_convert::convert_llvm::visit_block (jit_block& b) +{ + llvm::BasicBlock *block = b.to_llvm (); + builder.SetInsertPoint (block); + for (jit_block::iterator iter = b.begin (); iter != b.end (); ++iter) + visit (*iter); +} + +void +jit_convert::convert_llvm::visit_break (jit_break& b) +{ + builder.CreateBr (b.sucessor_llvm ()); +} + +void +jit_convert::convert_llvm::visit_cond_break (jit_cond_break& cb) +{ + llvm::Value *cond = cb.cond_llvm (); + builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); +} + +void +jit_convert::convert_llvm::visit_call (jit_call& call) +{ + const jit_function::overload& ol = call.overload (); + if (! ol.function) + fail (); + + std::vector args (call.argument_count ()); + for (size_t i = 0; i < call.argument_count (); ++i) + args[i] = call.argument_llvm (i); + + call.stash_llvm (builder.CreateCall (ol.function, args)); +} + +void +jit_convert::convert_llvm::visit_extract_argument (jit_extract_argument& extract) +{ + const jit_function::overload& ol = extract.overload (); + if (! ol.function) + fail (); + + llvm::Value *arg = arguments[extract.tag ()]; + arg = builder.CreateLoad (arg); + extract.stash_llvm (builder.CreateCall (ol.function, arg)); +} + +void +jit_convert::convert_llvm::visit_store_argument (jit_store_argument& store) +{ + llvm::Value *arg_value = store.result_llvm (); + const jit_function::overload& ol = store.overload (); + if (! ol.function) + fail (); + + arg_value = builder.CreateCall (ol.function, arg_value); + + llvm::Value *arg = arguments[store.tag ()]; + store.stash_llvm (builder.CreateStore (arg_value, arg)); +} + +// -------------------- tree_jit -------------------- + +tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) +{ + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); +} + +tree_jit::~tree_jit (void) +{} + +bool +tree_jit::execute (tree& cmd) +{ + if (! initialize ()) + return false; + + compiled_map::iterator iter = compiled.find (&cmd); + jit_info *jinfo = 0; + if (iter != compiled.end ()) + { + jinfo = iter->second; + if (! jinfo->match ()) + { + delete jinfo; + jinfo = 0; + } + } + + if (! jinfo) + { + jinfo = new jit_info (*this, cmd); + compiled[&cmd] = jinfo; + } + + return jinfo->execute (); +} + +bool +tree_jit::initialize (void) +{ + if (engine) + return true; + + // sometimes this fails pre main + engine = llvm::ExecutionEngine::createJIT (module); + + if (! engine) + return false; + + module_pass_manager = new llvm::PassManager (); + module_pass_manager->add (llvm::createAlwaysInlinerPass ()); + + pass_manager = new llvm::FunctionPassManager (module); + pass_manager->add (new llvm::TargetData(*engine->getTargetData ())); + pass_manager->add (llvm::createBasicAliasAnalysisPass ()); + pass_manager->add (llvm::createPromoteMemoryToRegisterPass ()); + pass_manager->add (llvm::createInstructionCombiningPass ()); + pass_manager->add (llvm::createReassociatePass ()); + pass_manager->add (llvm::createGVNPass ()); + pass_manager->add (llvm::createCFGSimplificationPass ()); + pass_manager->doInitialization (); + + jit_typeinfo::initialize (module, engine); + + return true; +} + + +void +tree_jit::optimize (llvm::Function *fn) +{ + module_pass_manager->run (*module); + pass_manager->run (*fn); +} + +// -------------------- jit_info -------------------- +jit_info::jit_info (tree_jit& tjit, tree& tee) + : engine (tjit.get_engine ()) +{ + llvm::Function *fun = 0; + try + { + jit_convert conv (tjit.get_module (), tee); + fun = conv.get_function (); + arguments = conv.get_arguments (); + bounds = conv.get_bounds (); + } + catch (const jit_fail_exception&) + {} + + if (! fun) + { + function = 0; + return; + } + + tjit.optimize (fun); + + if (debug_print) + { + std::cout << "-------------------- optimized llvm ir --------------------\n"; + llvm::raw_os_ostream llvm_cout (std::cout); + fun->print (llvm_cout); + std::cout << std::endl; + } + + function = reinterpret_cast(engine->getPointerToFunction (fun)); +} + +bool +jit_info::execute (void) const +{ + if (! function) + return false; + + std::vector real_arguments (arguments.size ()); + for (size_t i = 0; i < arguments.size (); ++i) + { + if (arguments[i].second) + { + octave_value current = symbol_table::varval (arguments[i].first); + octave_base_value *obv = current.internal_rep (); + obv->grab (); + real_arguments[i] = obv; + } + } + + function (&real_arguments[0]); + + for (size_t i = 0; i < arguments.size (); ++i) + symbol_table::varref (arguments[i].first) = real_arguments[i]; + + return true; +} + +bool +jit_info::match (void) const +{ + if (! function) + return true; + + for (size_t i = 0; i < bounds.size (); ++i) + { + const std::string& arg_name = bounds[i].second; + octave_value value = symbol_table::varval (arg_name); + jit_type *type = jit_typeinfo::type_of (value); + + // FIXME: Check for a parent relationship + if (type != bounds[i].first) + return false; + } + + return true; +} diff -r 22244a235fd0 -r f0499b0af646 src/pt-jit.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/pt-jit.h Thu May 24 15:48:10 2012 -0600 @@ -0,0 +1,1244 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_tree_jit_h) +#define octave_tree_jit_h 1 + +#include +#include +#include +#include +#include + +#include "Array.h" +#include "Range.h" +#include "pt-walk.h" +#include "symtab.h" + +// -------------------- Current status -------------------- +// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. +// However, there is no warning emitted on divide by 0. For example, +// a = 5; +// b = a * 5 + a; +// +// For other types all binary operations are compiled but not optimized. For +// example, +// a = [1 2 3] +// b = a + a; +// will compile to do_binary_op (a, a). +// +// for loops and if statements no longer compile! This is because work has been +// done to introduce a new lower level IR for octave. The low level IR looks +// a lot like llvm's IR, but it makes it much easier to infer types. You can set +// debug_print to true in pt-jit.cc to view the IRs that are created. +// +// The octave low level IR is a linear IR, it works by converting everything to +// calls to jit_functions. This turns expressions like c = a + b into +// c = call binary+ (a, b) +// The jit_functions contain information about overloads for differnt types. For +// example, if we know a and b are scalars, then c must also be a scalar. +// +// You will currently see a LARGE slowdown, as every statement is compiled +// seperatly! +// +// TODO: +// 1. Support for loops +// 2. Support if statements +// 3. Cleanup/documentation +// 4. ... +// --------------------------------------------------------- + + +// we don't want to include llvm headers here, as they require __STDC_LIMIT_MACROS +// and __STDC_CONSTANT_MACROS be defined in the entire compilation unit +namespace llvm +{ + class Value; + class Module; + class FunctionPassManager; + class PassManager; + class ExecutionEngine; + class Function; + class BasicBlock; + class LLVMContext; + class Type; + class Twine; +} + +class octave_base_value; +class octave_value; +class tree; + +// jit_range is compatable with the llvm range structure +struct +jit_range +{ + jit_range (void) {} + + jit_range (const Range& from) : base (from.base ()), limit (from.limit ()), + inc (from.inc ()), nelem (from.nelem ()) + {} + + operator Range () const + { + return Range (base, limit, inc); + } + + double base; + double limit; + double inc; + octave_idx_type nelem; +}; + +// Used to keep track of estimated (infered) types during JIT. This is a +// hierarchical type system which includes both concrete and abstract types. +// +// Current, we only support any and scalar types. If we can't figure out what +// type a variable is, we assign it the any type. This allows us to generate +// code even for the case of poor type inference. +class +jit_type +{ +public: + jit_type (const std::string& aname, jit_type *aparent, llvm::Type *allvm_type, + int aid) : + mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), + mdepth (aparent ? aparent->mdepth + 1 : 0) + {} + + // a user readable type name + const std::string& name (void) const { return mname; } + + // a unique id for the type + int type_id (void) const { return mid; } + + // An abstract base type, may be null + jit_type *parent (void) const { return mparent; } + + // convert to an llvm type + llvm::Type *to_llvm (void) const { return llvm_type; } + + // how this type gets passed as a function argument + llvm::Type *to_llvm_arg (void) const; + + size_t depth (void) const { return mdepth; } +private: + std::string mname; + jit_type *mparent; + llvm::Type *llvm_type; + int mid; + size_t mdepth; +}; + +// seperate print function to allow easy printing if type is null +static std::ostream& jit_print (std::ostream& os, jit_type *atype) +{ + if (! atype) + return os << "null"; + return os << atype->name (); +} + +// Keeps track of overloads for a builtin function. Used for both type inference +// and code generation. +class +jit_function +{ +public: + struct overload + { + overload (void) : function (0), can_error (true), result (0) {} + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) : + function (f), can_error (e), result (r), arguments (1) + { + arguments[0] = arg0; + } + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) : function (f), can_error (e), result (r), + arguments (2) + { + arguments[0] = arg0; + arguments[1] = arg1; + } + + llvm::Function *function; + bool can_error; + jit_type *result; + std::vector arguments; + }; + + void add_overload (const overload& func) + { + add_overload (func, func.arguments); + } + + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) + { + overload ol (f, e, r, arg0); + add_overload (ol); + } + + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) + { + overload ol (f, e, r, arg0, arg1); + add_overload (ol); + } + + void add_overload (const overload& func, + const std::vector& args); + + const overload& get_overload (const std::vector& types) const; + + const overload& get_overload (jit_type *arg0) const + { + std::vector types (1); + types[0] = arg0; + return get_overload (types); + } + + const overload& get_overload (jit_type *arg0, jit_type *arg1) const + { + std::vector types (2); + types[0] = arg0; + types[1] = arg1; + return get_overload (types); + } + + jit_type *get_result (const std::vector& types) const + { + const overload& temp = get_overload (types); + return temp.result; + } + + jit_type *get_result (jit_type *arg0, jit_type *arg1) const + { + const overload& temp = get_overload (arg0, arg1); + return temp.result; + } + + const std::string& name (void) const { return mname; } + + void stash_name (const std::string& aname) { mname = aname; } +private: + Array to_idx (const std::vector& types) const; + + std::vector > overloads; + + std::string mname; +}; + +// Get information and manipulate jit types. +class +jit_typeinfo +{ +public: + static void initialize (llvm::Module *m, llvm::ExecutionEngine *e); + + static jit_type *tunion (jit_type *lhs, jit_type *rhs) + { + return instance->do_union (lhs, rhs); + } + + static jit_type *difference (jit_type *lhs, jit_type *rhs) + { + return instance->do_difference (lhs, rhs); + } + + static jit_type *get_any (void) { return instance->any; } + + static jit_type *get_scalar (void) { return instance->scalar; } + + static jit_type *get_range (void) { return instance->range; } + + static jit_type *get_string (void) { return instance->string; } + + static jit_type *get_bool (void) { return instance->boolean; } + + static jit_type *get_index (void) { return instance->index; } + + static jit_type *type_of (const octave_value& ov) + { + return instance->do_type_of (ov); + } + + static const jit_function& binary_op (int op) + { + return instance->do_binary_op (op); + } + + static const jit_function& grab (void) { return instance->grab_fn; } + + static const jit_function& release (void) + { + return instance->release_fn; + } + + static const jit_function& print_value (void) + { + return instance->print_fn; + } + + static const jit_function& cast (jit_type *result) + { + return instance->do_cast (result); + } + + static const jit_function::overload& cast (jit_type *to, jit_type *from) + { + return instance->do_cast (to, from); + } +private: + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); + + // FIXME: Do these methods really need to be in jit_typeinfo? + jit_type *do_union (jit_type *lhs, jit_type *rhs) + { + // FIXME: Actually introduce a union type + + // empty case + if (! lhs) + return rhs; + + if (! rhs) + return lhs; + + // check for a shared parent + while (lhs != rhs) + { + if (lhs->depth () > rhs->depth ()) + lhs = lhs->parent (); + else if (lhs->depth () < rhs->depth ()) + rhs = rhs->parent (); + else + { + // we MUST have depth > 0 as any is the base type of everything + do + { + lhs = lhs->parent (); + rhs = rhs->parent (); + } + while (lhs != rhs); + } + } + + return lhs; + } + + jit_type *do_difference (jit_type *lhs, jit_type *) + { + // FIXME: Maybe we can do something smarter? + return lhs; + } + + jit_type *do_type_of (const octave_value &ov) const; + + const jit_function& do_binary_op (int op) const + { + assert (static_cast(op) < binary_ops.size ()); + return binary_ops[op]; + } + + const jit_function& do_cast (jit_type *to) + { + static jit_function null_function; + if (! to) + return null_function; + + size_t id = to->type_id (); + if (id >= casts.size ()) + return null_function; + return casts[id]; + } + + const jit_function::overload& do_cast (jit_type *to, jit_type *from) + { + return do_cast (to).get_overload (from); + } + + jit_type *new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type); + + + void add_print (jit_type *ty, void *call); + + void add_binary_op (jit_type *ty, int op, int llvm_op); + + void add_binary_icmp (jit_type *ty, int op, int llvm_op); + + void add_binary_fcmp (jit_type *ty, int op, int llvm_op); + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0) + { + std::vector args (1, arg0); + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + jit_type *arg0) + { + return create_function (name, ret->to_llvm (), arg0->to_llvm ()); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0, llvm::Type *arg1) + { + std::vector args (2); + args[0] = arg0; + args[1] = arg1; + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + jit_type *arg0, jit_type *arg1) + { + return create_function (name, ret->to_llvm (), arg0->to_llvm (), + arg1->to_llvm ()); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0, llvm::Type *arg1, + llvm::Type *arg2) + { + std::vector args (3); + args[0] = arg0; + args[1] = arg1; + args[2] = arg2; + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + jit_type *arg0, jit_type *arg1, + jit_type *arg2) + { + return create_function (name, ret->to_llvm (), arg0->to_llvm (), + arg1->to_llvm (), arg2->to_llvm ()); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + const std::vector& args); + + llvm::Function *create_identity (jit_type *type); + + static jit_typeinfo *instance; + + llvm::Module *module; + llvm::ExecutionEngine *engine; + int next_id; + + llvm::Type *ov_t; + + std::vector id_to_type; + jit_type *any; + jit_type *scalar; + jit_type *range; + jit_type *string; + jit_type *boolean; + jit_type *index; + + std::vector binary_ops; + jit_function grab_fn; + jit_function release_fn; + jit_function print_fn; + jit_function simple_for_check; + jit_function simple_for_incr; + jit_function simple_for_index; + jit_function logically_true; + + // type id -> cast function TO that type + std::vector casts; + + // type id -> identity function + std::vector identities; +}; + +// The low level octave jit ir +// this ir is close to llvm, but contains information for doing type inference. +// We convert the octave parse tree to this IR directly. + +#define JIT_VISIT_IR_CLASSES \ + JIT_METH(const_string); \ + JIT_METH(const_scalar); \ + JIT_METH(block); \ + JIT_METH(break); \ + JIT_METH(cond_break); \ + JIT_METH(call); \ + JIT_METH(extract_argument); \ + JIT_METH(store_argument) + + +#define JIT_METH(clname) class jit_ ## clname +JIT_VISIT_IR_CLASSES; +#undef JIT_METH + +class +jit_ir_walker +{ +public: + virtual ~jit_ir_walker () {} + +#define JIT_METH(clname) \ + virtual void visit_ ## clname (jit_ ## clname&) = 0 + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH +}; + +class jit_use; + +class +jit_value +{ + friend class jit_use; +public: + jit_value (void) : llvm_value (0), ty (0), use_head (0) {} + + virtual ~jit_value (void) {} + + jit_type *type () const { return ty; } + + void stash_type (jit_type *new_ty) { ty = new_ty; } + + jit_use *first_use (void) const { return use_head; } + + size_t use_count (void) const { return myuse_count; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) = 0; + + virtual std::ostream& short_print (std::ostream& os) + { return print (os); } + + virtual void accept (jit_ir_walker& walker) = 0; + + llvm::Value *to_llvm (void) const + { + return llvm_value; + } + + void stash_llvm (llvm::Value *compiled) + { + llvm_value = compiled; + } +protected: + std::ostream& print_indent (std::ostream& os, size_t indent) + { + for (size_t i = 0; i < indent; ++i) + os << "\t"; + return os; + } + + llvm::Value *llvm_value; +private: + jit_type *ty; + jit_use *use_head; + size_t myuse_count; +}; + +// defnie accept methods for subclasses +#define JIT_VALUE_ACCEPT(clname) \ + virtual void accept (jit_ir_walker& walker) \ + { \ + walker.visit_ ## clname (*this); \ + } + +class +jit_const_string : public jit_value +{ +public: + jit_const_string (const std::string& v) : val (v) + { + stash_type (jit_typeinfo::get_string ()); + } + + const std::string& value (void) const { return val; } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + return print_indent (os, indent) << "string: \"" << val << "\""; + } + + JIT_VALUE_ACCEPT (const_string) +private: + std::string val; +}; + +class +jit_const_scalar : public jit_value +{ +public: + jit_const_scalar (double avalue) : mvalue (avalue) + { + stash_type (jit_typeinfo::get_scalar ()); + } + + double value (void) const { return mvalue; } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + return print_indent (os, indent) << "scalar: \"" << mvalue << "\""; + } + + JIT_VALUE_ACCEPT (const_scalar) +private: + double mvalue; +}; + +class jit_instruction; + +class +jit_use +{ +public: + jit_use (void) : used (0), next_use (0), prev_use (0) {} + + ~jit_use (void) { remove (); } + + jit_value *value (void) const { return used; } + + size_t index (void) const { return idx; } + + jit_instruction *user (void) const { return usr; } + + void stash_value (jit_value *new_value, jit_instruction *u = 0, + size_t use_idx = -1) + { + remove (); + + used = new_value; + + if (used) + { + if (used->use_head) + { + used->use_head->prev_use = this; + next_use = used->use_head; + } + + used->use_head = this; + ++used->myuse_count; + } + + idx = use_idx; + usr = u; + } + + jit_use *next (void) const { return next_use; } + + jit_use *prev (void) const { return prev_use; } +private: + void remove (void) + { + if (used) + { + if (this == used->use_head) + used->use_head = next_use; + + if (prev_use) + prev_use->next_use = next_use; + + if (next_use) + next_use->prev_use = prev_use; + + next_use = prev_use = 0; + --used->myuse_count; + } + } + + jit_value *used; + jit_use *next_use; + jit_use *prev_use; + jit_instruction *usr; + size_t idx; +}; + +class +jit_instruction : public jit_value +{ +public: + // FIXME: this code could be so much pretier with varadic templates... +#define JIT_EXTRACT_ARG(idx) arguments[idx].stash_value (arg ## idx, this, idx) + + jit_instruction (void) : id (next_id ()) + { + } + + jit_instruction (jit_value *arg0) + : already_infered (1, reinterpret_cast(0)), arguments (1), + id (next_id ()) + { + JIT_EXTRACT_ARG (0); + } + + jit_instruction (jit_value *arg0, jit_value *arg1) + : already_infered (2, reinterpret_cast(0)), arguments (2), + id (next_id ()) + { + JIT_EXTRACT_ARG (0); + JIT_EXTRACT_ARG (1); + } + + jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2) + : already_infered (3, reinterpret_cast(0)), arguments (3), + id (next_id ()) + { + JIT_EXTRACT_ARG (0); + JIT_EXTRACT_ARG (1); + JIT_EXTRACT_ARG (2); + } + +#undef JIT_EXTRACT_ARG + + static void reset_ids (void) + { + next_id (true); + } + + jit_value *argument (size_t i) const + { + return arguments[i].value (); + } + + llvm::Value *argument_llvm (size_t i) const + { + return arguments[i].value ()->to_llvm (); + } + + jit_type *argument_type (size_t i) const + { + return arguments[i].value ()->type (); + } + + size_t argument_count (void) const + { + return arguments.size (); + } + + // argument types which have been infered already + const std::vector& argument_types (void) const + { return already_infered; } + + virtual bool infer (void) { return false; } + + virtual std::ostream& short_print (std::ostream& os) + { + if (mtag.empty ()) + jit_print (os, type ()) << ": #" << id; + else + jit_print (os, type ()) << ": " << mtag << "." << id; + + return os; + } + + const std::string& tag (void) const { return mtag; } + + void stash_tag (const std::string& atag) { mtag = atag; } +protected: + std::vector already_infered; +private: + static size_t next_id (bool reset = false) + { + static size_t ret = 0; + if (reset) + return ret = 0; + + return ret++; + } + + std::vector arguments; // DO NOT resize + + std::string mtag; + size_t id; +}; + +class +jit_block : public jit_value +{ +public: + typedef std::list instruction_list; + typedef instruction_list::iterator iterator; + typedef instruction_list::const_iterator const_iterator; + + jit_block (const std::string& n) : nm (n) {} + + virtual ~jit_block () + { + for (instruction_list::iterator iter = instructions.begin (); + iter != instructions.end (); ++iter) + delete *iter; + } + + const std::string& name (void) const { return nm; } + + jit_instruction *prepend (jit_instruction *instr) + { + instructions.push_front (instr); + return instr; + } + + jit_instruction *append (jit_instruction *instr) + { + instructions.push_back (instr); + return instr; + } + + iterator begin () { return instructions.begin (); } + + const_iterator begin () const { return instructions.begin (); } + + iterator end () { return instructions.end (); } + + const_iterator end () const { return instructions.begin (); } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent) << nm << ":" << std::endl; + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->print (os, indent + 1) << std::endl; + } + return os; + } + + llvm::BasicBlock *to_llvm (void) const; + + JIT_VALUE_ACCEPT (block) +private: + std::string nm; + instruction_list instructions; +}; + +class jit_terminator : public jit_instruction +{ +public: + jit_terminator (jit_value *arg0) : jit_instruction (arg0) {} + + jit_terminator (jit_value *arg0, jit_value *arg1, jit_value *arg2) + : jit_instruction (arg0, arg1, arg2) {} + + virtual jit_block *sucessor (size_t idx = 0) const = 0; + + llvm::BasicBlock *sucessor_llvm (size_t idx = 0) const + { + return sucessor (idx)->to_llvm (); + } + + virtual size_t sucessor_count (void) const = 0; +}; + +class +jit_break : public jit_terminator +{ +public: + jit_break (jit_block *succ) : jit_terminator (succ) {} + + jit_block *sucessor (size_t idx = 0) const + { + jit_value *arg = argument (idx); + return reinterpret_cast (arg); + } + + size_t sucessor_count (void) const { return 1; } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + jit_block *succ = sucessor (); + return print_indent (os, indent) << "break: " << succ->name (); + } + + JIT_VALUE_ACCEPT (break) +}; + +class +jit_cond_break : public jit_terminator +{ +public: + jit_cond_break (jit_value *c, jit_block *ctrue, jit_block *cfalse) + : jit_terminator (c, ctrue, cfalse) {} + + jit_value *cond (void) const { return argument (0); } + + llvm::Value *cond_llvm (void) const + { + return cond ()->to_llvm (); + } + + jit_block *sucessor (size_t idx) const + { + jit_value *arg = argument (idx + 1); + return reinterpret_cast (arg); + } + + size_t sucessor_count (void) const { return 2; } + + JIT_VALUE_ACCEPT (cond_break) +}; + +class +jit_call : public jit_instruction +{ +public: + jit_call (const jit_function& afunction, + jit_value *arg0) : jit_instruction (arg0), mfunction (afunction) {} + + jit_call (const jit_function& afunction, + jit_value *arg0, jit_value *arg1) : jit_instruction (arg0, arg1), + mfunction (afunction) {} + + const jit_function& function (void) const { return mfunction; } + + const jit_function::overload& overload (void) const + { + return mfunction.get_overload (argument_types ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent); + + if (use_count ()) + short_print (os) << " = "; + os << "call " << mfunction.name () << " ("; + + for (size_t i = 0; i < argument_count (); ++i) + { + jit_value *arg = argument (i); + arg->short_print (os); + if (i + 1 < argument_count ()) + os << ", "; + } + return os << ")"; + } + + virtual bool infer (void); + + JIT_VALUE_ACCEPT (call) +private: + const jit_function& mfunction; +}; + +class +jit_extract_argument : public jit_instruction +{ +public: + jit_extract_argument (jit_type *atype, const std::string& aname) + : jit_instruction () + { + stash_type (atype); + stash_tag (aname); + } + + const jit_function::overload& overload (void) const + { + return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent); + return short_print (os) << " = extract: " << tag (); + } + + JIT_VALUE_ACCEPT (extract_argument) +}; + +class +jit_store_argument : public jit_instruction +{ +public: + jit_store_argument (const std::string& aname, jit_value *aresult) + : jit_instruction (aresult) + { + stash_tag (aname); + } + + const jit_function::overload& overload (void) const + { + return jit_typeinfo::cast (jit_typeinfo::get_any (), result_type ()); + } + + jit_value *result (void) const + { + return argument (0); + } + + jit_type *result_type (void) const + { + return result ()->type (); + } + + llvm::Value *result_llvm (void) const + { + return result ()->to_llvm (); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + jit_value *res = result (); + print_indent (os, indent) << tag () << " <- "; + return res->short_print (os); + } + + JIT_VALUE_ACCEPT (store_argument) +}; + +// convert between IRs +// FIXME: Class relationships are messy from here on down. They need to be +// cleaned up. +class +jit_convert : public tree_walker +{ +public: + typedef std::pair type_bound; + typedef std::vector type_bound_vector; + + jit_convert (llvm::Module *module, tree &tee); + + llvm::Function *get_function (void) const { return function; } + + const std::vector >& get_arguments(void) const + { return arguments; } + + const type_bound_vector& get_bounds (void) const { return bounds; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); +private: + std::vector > arguments; + type_bound_vector bounds; + + typedef std::map variable_map; + variable_map variables; + + // used instead of return values from visit_* functions + jit_value *result; + + jit_block *block; + jit_block *entry_block; + jit_block *final_block; + + llvm::Function *function; + + std::list blocks; + + std::list worklist; + + std::list constants; + + void do_assign (const std::string& lhs, jit_value *rhs, bool print); + + jit_value *visit (tree *tee) { return visit (*tee); } + + jit_value *visit (tree& tee); + + void append_users (jit_value *v) + { + for (jit_use *use = v->first_use (); use; use = use->next ()) + worklist.push_back (use->user ()); + } + + jit_const_scalar *get_scalar (double v) + { + jit_const_scalar *ret = new jit_const_scalar (v); + constants.push_back (ret); + return ret; + } + + jit_const_string *get_string (const std::string& v) + { + jit_const_string *ret = new jit_const_string (v); + constants.push_back (ret); + return ret; + } + + // this case is much simpler, just convert from the jit ir to llvm + class + convert_llvm : public jit_ir_walker + { + public: + llvm::Function *convert (llvm::Module *module, + const std::vector >& args, + const std::list& blocks, + const std::list& constants); + +#define JIT_METH(clname) \ + virtual void visit_ ## clname (jit_ ## clname&); + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH + private: + // name -> llvm argument + std::map arguments; + + + void visit (jit_value *jvalue) + { + return visit (*jvalue); + } + + void visit (jit_value &jvalue) + { + jvalue.accept (*this); + } + }; +}; + +class jit_info; + +class +tree_jit +{ +public: + tree_jit (void); + + ~tree_jit (void); + + bool execute (tree& cmd); + + llvm::ExecutionEngine *get_engine (void) const { return engine; } + + llvm::Module *get_module (void) const { return module; } + + void optimize (llvm::Function *fn); + private: + bool initialize (void); + + // FIXME: Temorary hack to test + typedef std::map compiled_map; + compiled_map compiled; + + llvm::LLVMContext &context; + llvm::Module *module; + llvm::PassManager *module_pass_manager; + llvm::FunctionPassManager *pass_manager; + llvm::ExecutionEngine *engine; +}; + +class +jit_info +{ +public: + jit_info (tree_jit& tjit, tree& tee); + + bool execute (void) const; + + bool match (void) const; +private: + typedef jit_convert::type_bound type_bound; + typedef jit_convert::type_bound_vector type_bound_vector; + typedef void (*jited_function)(octave_base_value**); + + llvm::ExecutionEngine *engine; + jited_function function; + + std::vector > arguments; + type_bound_vector bounds; +}; + +#endif diff -r 22244a235fd0 -r f0499b0af646 src/pt-loop.cc --- a/src/pt-loop.cc Thu May 24 15:38:59 2012 -0400 +++ b/src/pt-loop.cc Thu May 24 15:48:10 2012 -0600 @@ -35,6 +35,7 @@ #include "pt-bp.h" #include "pt-cmd.h" #include "pt-exp.h" +#include "pt-jit.h" #include "pt-jump.h" #include "pt-loop.h" #include "pt-stmt.h" @@ -97,6 +98,10 @@ delete list; delete lead_comm; delete trail_comm; + + for (compiled_map::iterator iter = compiled.begin (); iter != compiled.end (); + ++iter) + delete iter->second; } tree_command * diff -r 22244a235fd0 -r f0499b0af646 src/pt-loop.h --- a/src/pt-loop.h Thu May 24 15:38:59 2012 -0400 +++ b/src/pt-loop.h Thu May 24 15:48:10 2012 -0600 @@ -36,6 +36,9 @@ #include "pt-cmd.h" #include "symtab.h" +class jit_info; +class jit_type; + // While. class @@ -180,7 +183,20 @@ void accept (tree_walker& tw); + // some functions use by tree_jit + jit_info *get_info (jit_type *type) const + { + compiled_map::const_iterator iter = compiled.find (type); + return iter != compiled.end () ? iter->second : 0; + } + + void stash_info (jit_type *type, jit_info *jinfo) + { + compiled[type] = jinfo; + } + private: + typedef std::map compiled_map; // TRUE means operate in parallel (subject to the value of the // maxproc expression). @@ -205,6 +221,9 @@ // Comment preceding ENDFOR token. octave_comment_list *trail_comm; + // a map from iterator types -> compiled functions + compiled_map compiled; + // No copying! tree_simple_for_command (const tree_simple_for_command&); diff -r 22244a235fd0 -r f0499b0af646 src/pt-stmt.h --- a/src/pt-stmt.h Thu May 24 15:38:59 2012 -0400 +++ b/src/pt-stmt.h Thu May 24 15:48:10 2012 -0600 @@ -35,12 +35,13 @@ #include "base-list.h" #include "comment-list.h" #include "symtab.h" +#include "pt.h" // A statement is either a command to execute or an expression to // evaluate. class -tree_statement +tree_statement : public tree { public: diff -r 22244a235fd0 -r f0499b0af646 src/symtab.h --- a/src/symtab.h Thu May 24 15:38:59 2012 -0400 +++ b/src/symtab.h Thu May 24 15:48:10 2012 -0600 @@ -484,7 +484,7 @@ return symbol_record (rep->dup (new_scope)); } - std::string name (void) const { return rep->name; } + const std::string& name (void) const { return rep->name; } octave_value find (const octave_value_list& args = octave_value_list ()) const; @@ -581,6 +581,66 @@ symbol_record (symbol_record_rep *new_rep) : rep (new_rep) { } }; + // Always access a symbol from the current scope. + // Useful for scripts, as they may be executed in more than one scope. + class + symbol_record_ref + { + public: + symbol_record_ref (void) : scope (-1) {} + + symbol_record_ref (symbol_record record, + scope_id curr_scope = symbol_table::current_scope ()) + : scope (curr_scope), sym (record) + {} + + symbol_record_ref& operator = (const symbol_record_ref& ref) + { + scope = ref.scope; + sym = ref.sym; + return *this; + } + + // The name is the same regardless of scope. + const std::string& name (void) const { return sym.name (); } + + symbol_record *operator-> (void) + { + update (); + return &sym; + } + + symbol_record *operator-> (void) const + { + update (); + return &sym; + } + + // can be used to place symbol_record_ref in maps, we don't overload < as + // it doesn't make any sense for symbol_record_ref + struct comparator + { + bool operator ()(const symbol_record_ref& lhs, + const symbol_record_ref& rhs) const + { + return lhs.name () < rhs.name (); + } + }; + private: + void update (void) const + { + scope_id curr_scope = symbol_table::current_scope (); + if (scope != curr_scope || ! sym.is_valid ()) + { + scope = curr_scope; + sym = symbol_table::insert (sym.name ()); + } + } + + mutable scope_id scope; + mutable symbol_record sym; + }; + class fcn_info { diff -r 22244a235fd0 -r f0499b0af646 src/toplev.cc --- a/src/toplev.cc Thu May 24 15:38:59 2012 -0400 +++ b/src/toplev.cc Thu May 24 15:48:10 2012 -0600 @@ -1325,6 +1325,9 @@ { false, "MAGICK_CPPFLAGS", OCTAVE_CONF_MAGICK_CPPFLAGS }, { false, "MAGICK_LDFLAGS", OCTAVE_CONF_MAGICK_LDFLAGS }, { false, "MAGICK_LIBS", OCTAVE_CONF_MAGICK_LIBS }, + { false, "LLVM_CPPFLAGS", OCTAVE_CONF_LLVM_CPPFLAGS }, + { false, "LLVM_LDFLAGS", OCTAVE_CONF_LLVM_LDFLAGS }, + { false, "LLVM_LIBS", OCTAVE_CONF_LLVM_LIBS }, { false, "MKOCTFILE_DL_LDFLAGS", OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS }, { false, "OCTAVE_LINK_DEPS", OCTAVE_CONF_OCTAVE_LINK_DEPS }, { false, "OCTAVE_LINK_OPTS", OCTAVE_CONF_OCTAVE_LINK_OPTS },