# HG changeset patch # User Jordi GutiƩrrez Hermoso # Date 1337804525 14400 # Node ID 0b0569667939451c8d7332fb5563c3a4df2666b6 # Parent 757f729fd41dc4fed00342050f25448650d5be36# Parent cba58541954c4373ab034eb12d5c20bbe02d060d maint: Periodic merge of default to jit diff -r 757f729fd41d -r 0b0569667939 build-aux/common.mk --- a/build-aux/common.mk Wed May 23 13:36:24 2012 -0400 +++ b/build-aux/common.mk Wed May 23 16:22:05 2012 -0400 @@ -544,6 +544,9 @@ -e "s|%OCTAVE_CONF_MAGICK_CPPFLAGS%|\"${MAGICK_CPPFLAGS}\"|" \ -e "s|%OCTAVE_CONF_MAGICK_LDFLAGS%|\"${MAGICK_LDFLAGS}\"|" \ -e "s|%OCTAVE_CONF_MAGICK_LIBS%|\"${MAGICK_LIBS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_CPPFLAGS%|\"${LLVM_CPPFLAGS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_LDFLAGS%|\"${LLVM_LDFLAGS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_LIBS%|\"${LLVM_LIBS}\"|" \ -e 's|%OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS%|\"@MKOCTFILE_DL_LDFLAGS@\"|' \ -e "s|%OCTAVE_CONF_OCTAVE_LINK_DEPS%|\"${OCTAVE_LINK_DEPS}\"|" \ -e "s|%OCTAVE_CONF_OCTAVE_LINK_OPTS%|\"${OCTAVE_LINK_OPTS}\"|" \ diff -r 757f729fd41d -r 0b0569667939 configure.ac --- a/configure.ac Wed May 23 13:36:24 2012 -0400 +++ b/configure.ac Wed May 23 16:22:05 2012 -0400 @@ -712,6 +712,59 @@ [ZLIB library not found. Octave will not be able to save or load compressed data files or HDF5 files.], [zlib.h], [gzclearerr]) +### Check for the llvm library +dnl +dnl +dnl llvm is odd and has its own pkg-config like script. We should probably check +dnl for existance and +dnl + +LLVM_CONFIG=llvm-config +LLVM_CPPFLAGS= +LLVM_LDFLAGS= +LLVM_LIBS= + +LLVM_LDFLAGS=`$LLVM_CONFIG --ldflags` +LLVM_LIBS=`$LLVM_CONFIG --libs` +LLVM_CPPFLAGS=`$LLVM_CONFIG --cxxflags` + +warn_llvm="LLVM library fails tests. JIT compilation will be disabled." + +save_CPPFLAGS="$CPPFLAGS" +save_LIBS="$LIBS" +save_LDFLAGS="$LDFLAGS" +CPPFLAGS="$LLVM_CPPFLAGS $CPPFLAGS" +LIBS="$LLVM_LIBS $LIBS" +LDFLAGS="$LLVM_LDFLAGS $LDFLAGS" +AC_LANG_PUSH(C++) + AC_CHECK_HEADER([llvm/LLVMContext.h], [ + AC_MSG_CHECKING([for llvm::getGlobalContext in llvm/LLVMContext.h]) + AC_TRY_LINK([#include ], [llvm::getGlobalContext ()], [ + AC_MSG_RESULT(yes) + warn_llvm= + ], [ + AC_MSG_RESULT(no) + ]) + ]) +AC_LANG_POP(C++) +CPPFLAGS="$save_CPPFLAGS" +LIBS="$save_LIBS" +LDFLAGS="$save_LDFLAGS" + +if test -z "$warn_llvm"; then + AC_DEFINE(HAVE_LLVM, 1, [Define if LLVM is available.]) +else + LLVM_CPPFLAGS= + LLVM_LDFLAGS= + LLVM_LIBS= + AC_MSG_WARN([$warn_llvm]) +fi +AC_DEFINE(HAVE_LLVM, 1, [Define if LLVM is available.]) + +AC_SUBST(LLVM_CPPFLAGS) +AC_SUBST(LLVM_LDFLAGS) +AC_SUBST(LLVM_LIBS) + ### Check for HDF5 library. save_CPPFLAGS="$CPPFLAGS" @@ -2242,6 +2295,9 @@ Magick++ CPPFLAGS: $MAGICK_CPPFLAGS Magick++ LDFLAGS: $MAGICK_LDFLAGS Magick++ libraries: $MAGICK_LIBS + LLVM CPPFLAGS: $LLVM_CPPFLAGS + LLVM LDFLAGS: $LLVM_LDFLAGS + LLVM Libraries: $LLVM_LIBS HDF5 CPPFLAGS: $HDF5_CPPFLAGS HDF5 LDFLAGS: $HDF5_LDFLAGS HDF5 libraries: $HDF5_LIBS diff -r 757f729fd41d -r 0b0569667939 src/Makefile.am --- a/src/Makefile.am Wed May 23 13:36:24 2012 -0400 +++ b/src/Makefile.am Wed May 23 16:22:05 2012 -0400 @@ -220,6 +220,7 @@ pt-fcn-handle.h \ pt-id.h \ pt-idx.h \ + pt-jit.h \ pt-jump.h \ pt-loop.h \ pt-mat.h \ @@ -392,6 +393,7 @@ pt-fcn-handle.cc \ pt-id.cc \ pt-idx.cc \ + pt-jit.cc \ pt-jump.cc \ pt-loop.cc \ pt-mat.cc \ diff -r 757f729fd41d -r 0b0569667939 src/TEMPLATE-INST/Array-jit.cc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/TEMPLATE-INST/Array-jit.cc Wed May 23 16:22:05 2012 -0400 @@ -0,0 +1,34 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "Array.h" +#include "Array.cc" + +#include "pt-jit.h" + +NO_INSTANTIATE_ARRAY_SORT (jit_function::overload); + +INSTANTIATE_ARRAY (jit_function::overload, OCTINTERP_API); diff -r 757f729fd41d -r 0b0569667939 src/TEMPLATE-INST/module.mk --- a/src/TEMPLATE-INST/module.mk Wed May 23 13:36:24 2012 -0400 +++ b/src/TEMPLATE-INST/module.mk Wed May 23 16:22:05 2012 -0400 @@ -2,4 +2,5 @@ TEMPLATE_INST_SRC = \ TEMPLATE-INST/Array-os.cc \ - TEMPLATE-INST/Array-tc.cc + TEMPLATE-INST/Array-tc.cc \ + TEMPLATE-INST/Array-jit.cc diff -r 757f729fd41d -r 0b0569667939 src/link-deps.mk --- a/src/link-deps.mk Wed May 23 13:36:24 2012 -0400 +++ b/src/link-deps.mk Wed May 23 16:22:05 2012 -0400 @@ -13,14 +13,16 @@ $(Z_LIBS) \ $(OPENGL_LIBS) \ $(X11_LIBS) \ - $(CARBON_LIBS) + $(CARBON_LIBS) \ + $(LLVM_LIBS) LIBOCTINTERP_LINK_OPTS = \ $(GRAPHICS_LDFLAGS) \ $(FT2_LDFLAGS) \ $(HDF5_LDFLAGS) \ $(Z_LDFLAGS) \ - $(REGEX_LDFLAGS) + $(REGEX_LDFLAGS) \ + $(LLVM_LDFLAGS) OCT_LINK_DEPS = diff -r 757f729fd41d -r 0b0569667939 src/oct-conf.in.h --- a/src/oct-conf.in.h Wed May 23 13:36:24 2012 -0400 +++ b/src/oct-conf.in.h Wed May 23 16:22:05 2012 -0400 @@ -384,6 +384,18 @@ #define OCTAVE_CONF_MAGICK_LIBS %OCTAVE_CONF_MAGICK_LIBS% #endif +#ifndef OCTAVE_CONF_LLVM_CPPFLAGS +#define OCTAVE_CONF_LLVM_CPPFLAGS %OCTAVE_CONF_LLVM_CPPFLAGS% +#endif + +#ifndef OCTAVE_CONF_LLVM_LDFLAGS +#define OCTAVE_CONF_LLVM_LDFLAGS %OCTAVE_CONF_LLVM_LDFLAGS% +#endif + +#ifndef OCTAVE_CONF_LLVM_LIBS +#define OCTAVE_CONF_LLVM_LIBS %OCTAVE_CONF_LLVM_LIBS% +#endif + #ifndef OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS #define OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS %OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS% #endif diff -r 757f729fd41d -r 0b0569667939 src/ov-base.h --- a/src/ov-base.h Wed May 23 13:36:24 2012 -0400 +++ b/src/ov-base.h Wed May 23 16:22:05 2012 -0400 @@ -755,6 +755,21 @@ virtual bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + // Grab the reference count. For use by jit. + void + grab (void) + { + ++count; + } + + // Release the reference count. For use by jit. + void + release (void) + { + if (--count == 0) + delete this; + } + protected: // This should only be called for derived types. diff -r 757f729fd41d -r 0b0569667939 src/pt-eval.cc --- a/src/pt-eval.cc Wed May 23 13:36:24 2012 -0400 +++ b/src/pt-eval.cc Wed May 23 16:22:05 2012 -0400 @@ -44,6 +44,10 @@ #include "symtab.h" #include "unwind-prot.h" +//FIXME: This should be part of tree_evaluator +#include "pt-jit.h" +static tree_jit jiter; + static tree_evaluator std_evaluator; tree_evaluator *current_evaluator = &std_evaluator; @@ -306,6 +310,9 @@ if (error_state || rhs.is_undefined ()) return; + if (jiter.execute (cmd, rhs)) + return; + { tree_expression *lhs = cmd.left_hand_side (); diff -r 757f729fd41d -r 0b0569667939 src/pt-id.cc --- a/src/pt-id.cc Wed May 23 13:36:24 2012 -0400 +++ b/src/pt-id.cc Wed May 23 16:22:05 2012 -0400 @@ -63,7 +63,7 @@ if (error_state) return retval; - octave_value val = xsym ().find (); + octave_value val = sym->find (); if (val.is_defined ()) { @@ -114,7 +114,7 @@ octave_lvalue tree_identifier::lvalue (void) { - return octave_lvalue (&(xsym().varref ())); + return octave_lvalue (&(sym->varref ())); } tree_identifier * diff -r 757f729fd41d -r 0b0569667939 src/pt-id.h --- a/src/pt-id.h Wed May 23 13:36:24 2012 -0400 +++ b/src/pt-id.h Wed May 23 16:22:05 2012 -0400 @@ -46,12 +46,12 @@ public: tree_identifier (int l = -1, int c = -1) - : tree_expression (l, c), sym (), scope (-1) { } + : tree_expression (l, c) { } tree_identifier (const symbol_table::symbol_record& s, int l = -1, int c = -1, symbol_table::scope_id sc = symbol_table::current_scope ()) - : tree_expression (l, c), sym (s), scope (sc) { } + : tree_expression (l, c), sym (s, sc) { } ~tree_identifier (void) { } @@ -63,9 +63,9 @@ // accessing it through sym so that this function may remain const. std::string name (void) const { return sym.name (); } - bool is_defined (void) { return xsym().is_defined (); } + bool is_defined (void) { return sym->is_defined (); } - virtual bool is_variable (void) { return xsym().is_variable (); } + virtual bool is_variable (void) { return sym->is_variable (); } virtual bool is_black_hole (void) { return false; } @@ -87,14 +87,14 @@ octave_value do_lookup (const octave_value_list& args = octave_value_list ()) { - return xsym().find (args); + return sym->find (args); } - void mark_global (void) { xsym().mark_global (); } + void mark_global (void) { sym->mark_global (); } - void mark_as_static (void) { xsym().init_persistent (); } + void mark_as_static (void) { sym->init_persistent (); } - void mark_as_formal_parameter (void) { xsym().mark_formal (); } + void mark_as_formal_parameter (void) { sym->mark_formal (); } // We really need to know whether this symbol referst to a variable // or a function, but we may not know that yet. @@ -114,28 +114,14 @@ void accept (tree_walker& tw); + symbol_table::symbol_record_ref symbol (void) const + { + return sym; + } private: // The symbol record that this identifier references. - symbol_table::symbol_record sym; - - symbol_table::scope_id scope; - - // A script may be executed in multiple scopes. If the last one was - // different from the one we are in now, update sym to be from the - // new scope. - symbol_table::symbol_record& xsym (void) - { - symbol_table::scope_id curr_scope = symbol_table::current_scope (); - - if (scope != curr_scope || ! sym.is_valid ()) - { - scope = curr_scope; - sym = symbol_table::insert (sym.name ()); - } - - return sym; - } + symbol_table::symbol_record_ref sym; // No copying! diff -r 757f729fd41d -r 0b0569667939 src/pt-jit.cc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/pt-jit.cc Wed May 23 16:22:05 2012 -0400 @@ -0,0 +1,1866 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "pt-jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "octave.h" +#include "ov-fcn-handle.h" +#include "ov-usr-fcn.h" +#include "pt-all.h" + +// FIXME: Remove eventually +// For now we leave this in so people tell when JIT actually happens +static const bool debug_print = false; + +static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); + +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::exception {}; + +static void +fail (void) +{ + throw jit_fail_exception (); +} + +// function that jit code calls +extern "C" void +octave_jit_print_any (const char *name, octave_base_value *obv) +{ + obv->print_with_name (octave_stdout, name, true); +} + +extern "C" void +octave_jit_print_double (const char *name, double value) +{ + // FIXME: We should avoid allocating a new octave_scalar each time + octave_value ov (value); + ov.print_with_name (octave_stdout, name); +} + +extern "C" octave_base_value* +octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs, + octave_base_value *rhs) +{ + octave_value olhs (lhs); + octave_value orhs (rhs); + octave_value result = do_binary_op (op, olhs, orhs); + octave_base_value *rep = result.internal_rep (); + rep->grab (); + return rep; +} + +extern "C" void +octave_jit_release_any (octave_base_value *obv) +{ + obv->release (); +} + +extern "C" void +octave_jit_grab_any (octave_base_value *obv) +{ + obv->grab (); +} + +// -------------------- jit_type -------------------- +llvm::Type * +jit_type::to_llvm_arg (void) const +{ + return llvm_type ? llvm_type->getPointerTo () : 0; +} + +// -------------------- jit_function -------------------- +void +jit_function::add_overload (const overload& func, + const std::vector& args) +{ + if (args.size () >= overloads.size ()) + overloads.resize (args.size () + 1); + + Array& over = overloads[args.size ()]; + dim_vector dv (over.dims ()); + Array idx = to_idx (args); + bool must_resize = false; + + if (dv.length () != idx.numel ()) + { + dv.resize (idx.numel ()); + must_resize = true; + } + + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (dv(i) <= idx(i)) + { + must_resize = true; + dv(i) = idx(i) + 1; + } + + if (must_resize) + over.resize (dv); + + over(idx) = func; +} + +const jit_function::overload& +jit_function::get_overload (const std::vector& types) const +{ + // FIXME: We should search for the next best overload on failure + static overload null_overload; + if (types.size () >= overloads.size ()) + return null_overload; + + const Array& over = overloads[types.size ()]; + dim_vector dv (over.dims ()); + Array idx = to_idx (types); + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (idx(i) >= dv(i)) + return null_overload; + + return over(idx); +} + +Array +jit_function::to_idx (const std::vector& types) const +{ + octave_idx_type numel = types.size (); + if (numel == 1) + numel = 2; + + Array idx (dim_vector (1, numel)); + for (octave_idx_type i = 0; i < static_cast (types.size ()); + ++i) + idx(i) = types[i]->type_id (); + + if (types.size () == 1) + { + idx(1) = idx(0); + idx(0) = 0; + } + + return idx; +} + +// -------------------- jit_typeinfo -------------------- +jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) + : module (m), engine (e), next_id (0) +{ + // FIXME: We should be registering types like in octave_value_typeinfo + llvm::LLVMContext &ctx = m->getContext (); + + ov_t = llvm::StructType::create (ctx, "octave_base_value"); + ov_t = ov_t->getPointerTo (); + + llvm::Type *dbl = llvm::Type::getDoubleTy (ctx); + llvm::Type *bool_t = llvm::Type::getInt1Ty (ctx); + llvm::Type *index_t = 0; + switch (sizeof(octave_idx_type)) + { + case 4: + index_t = llvm::Type::getInt32Ty (ctx); + break; + case 8: + index_t = llvm::Type::getInt64Ty (ctx); + break; + default: + assert (false && "Unrecognized index type size"); + } + + llvm::StructType *range_t = llvm::StructType::create (ctx, "range"); + std::vector range_contents (4, dbl); + range_contents[3] = index_t; + range_t->setBody (range_contents); + + // create types + any = new_type ("any", true, 0, ov_t); + scalar = new_type ("scalar", false, any, dbl); + range = new_type ("range", false, any, range_t); + boolean = new_type ("bool", false, any, bool_t); + index = new_type ("index", false, any, index_t); + + // any with anything is an any op + llvm::Function *fn; + llvm::Type *binary_op_type + = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op)); + llvm::Function *any_binary = create_function ("octave_jit_binary_any_any", + any->to_llvm (), binary_op_type, + any->to_llvm (), any->to_llvm ()); + engine->addGlobalMapping (any_binary, + reinterpret_cast(&octave_jit_binary_any_any)); + + binary_ops.resize (octave_value::num_binary_ops); + for (int op = 0; op < octave_value::num_binary_ops; ++op) + { + llvm::Twine fn_name ("octave_jit_binary_any_any_"); + fn_name = fn_name + llvm::Twine (op); + fn = create_function (fn_name, any, any, any); + llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (block); + llvm::APInt op_int(sizeof (octave_value::binary_op), op, + std::numeric_limits::is_signed); + llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int); + llvm::Value *ret = builder.CreateCall3 (any_binary, + op_as_llvm, + fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + + jit_function::overload overload (fn, true, any, any, any); + for (octave_idx_type i = 0; i < next_id; ++i) + binary_ops[op].add_overload (overload); + } + + llvm::Type *void_t = llvm::Type::getVoidTy (ctx); + + // grab any + fn = create_function ("octave_jit_grab_any", void_t, any->to_llvm ()); + + engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_grab_any)); + grab_fn.add_overload (fn, false, 0, any); + + // release any + fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); + engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_release_any)); + release_fn.add_overload (fn, false, 0, any); + + // now for binary scalar operations + // FIXME: Finish all operations + add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); + add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul); + add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul); + + // FIXME: Warn if rhs is zero + add_binary_op (scalar, octave_value::op_div, llvm::Instruction::FDiv); + add_binary_op (scalar, octave_value::op_el_div, llvm::Instruction::FDiv); + + add_binary_fcmp (scalar, octave_value::op_lt, llvm::CmpInst::FCMP_ULT); + add_binary_fcmp (scalar, octave_value::op_le, llvm::CmpInst::FCMP_ULE); + add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); + add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); + add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); + add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + + // now for printing functions + add_print (any, reinterpret_cast (&octave_jit_print_any)); + add_print (scalar, reinterpret_cast (&octave_jit_print_double)); + + // bounds check for for loop + fn = create_function ("octave_jit_simple_for_range", boolean, range, index); + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *nelem + = builder.CreateExtractValue (fn->arg_begin (), 3); + // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ()); + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *ret = builder.CreateICmpULT (idx, nelem); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_check.add_overload (fn, false, boolean, range, index); + + // increment for for loop + fn = create_function ("octave_jit_imple_for_range_incr", index, index); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + llvm::Value *idx = fn->arg_begin (); + llvm::Value *ret = builder.CreateAdd (idx, one); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_incr.add_overload (fn, false, index, index); + + // index variabe for for loop + fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *didx = builder.CreateUIToFP (idx, dbl); + llvm::Value *rng = fn->arg_begin (); + llvm::Value *base = builder.CreateExtractValue (rng, 0); + llvm::Value *inc = builder.CreateExtractValue (rng, 2); + + llvm::Value *ret = builder.CreateFMul (didx, inc); + ret = builder.CreateFAdd (base, ret); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_index.add_overload (fn, false, scalar, range, index); + + // logically true + // FIXME: Check for NaN + fn = create_function ("octave_logically_true_scalar", boolean, scalar); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0); + llvm::Value *ret = builder.CreateFCmpUNE (fn->arg_begin (), zero); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + logically_true.add_overload (fn, true, boolean, scalar); + + fn = create_function ("octave_logically_true_bool", boolean, boolean); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + logically_true.add_overload (fn, false, boolean, boolean); +} + +void +jit_typeinfo::add_print (jit_type *ty, void *call) +{ + std::stringstream name; + name << "octave_jit_print_" << ty->name (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Type *void_t = llvm::Type::getVoidTy (ctx); + llvm::Function *fn = create_function (name.str (), void_t, + llvm::Type::getInt8PtrTy (ctx), + ty->to_llvm ()); + engine->addGlobalMapping (fn, call); + + jit_function::overload ol (fn, false, 0, ty); + print_fn.add_overload (ol); +} + +// FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp +void +jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Function *fn = create_function (fname.str (), ty, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (block); + llvm::Instruction::BinaryOps temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateBinOp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol(fn, false, ty, ty, ty); + binary_ops[op].add_overload (ol); +} + +void +jit_typeinfo::add_binary_icmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateICmp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol (fn, false, boolean, ty, ty); + binary_ops[op].add_overload (ol); +} + +void +jit_typeinfo::add_binary_fcmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateFCmp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol (fn, false, boolean, ty, ty); + binary_ops[op].add_overload (ol); +} + +llvm::Function * +jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret, + const std::vector& args) +{ + llvm::FunctionType *ft = llvm::FunctionType::get (ret, args, false); + llvm::Function *fn = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + name, module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + return fn; +} + +jit_type* +jit_typeinfo::type_of (const octave_value &ov) const +{ + if (ov.is_undefined () || ov.is_function ()) + return 0; + + if (ov.is_double_type () && ov.is_real_scalar ()) + return get_scalar (); + + if (ov.is_range ()) + return get_range (); + + return get_any (); +} + +const jit_function& +jit_typeinfo::binary_op (int op) const +{ + assert (static_cast(op) < binary_ops.size ()); + return binary_ops[op]; +} + +const jit_function::overload& +jit_typeinfo::print_value (jit_type *to_print) const +{ + return print_fn.get_overload (to_print); +} + +void +jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv) +{ + if (type == any) + to_generic (type, gv, octave_value ()); + else if (type == scalar) + to_generic (type, gv, octave_value (0)); + else if (type == range) + to_generic (type, gv, octave_value (Range ())); + else + assert (false && "Type not supported yet"); +} + +void +jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) +{ + if (type == any) + { + octave_base_value *obv = ov.internal_rep (); + obv->grab (); + ov_out.push_back (obv); + gv.PointerVal = &ov_out.back (); + } + else if (type == scalar) + { + scalar_out.push_back (ov.double_value ()); + gv.PointerVal = &scalar_out.back (); + } + else if (type == range) + { + range_out.push_back (ov.range_value ()); + gv.PointerVal = &range_out.back (); + } + else + assert (false && "Type not supported yet"); +} + +octave_value +jit_typeinfo::to_octave_value (jit_type *type, llvm::GenericValue& gv) +{ + if (type == any) + { + octave_base_value **ptr = reinterpret_cast(gv.PointerVal); + return octave_value (*ptr); + } + else if (type == scalar) + { + double *ptr = reinterpret_cast(gv.PointerVal); + return octave_value (*ptr); + } + else if (type == range) + { + jit_range *ptr = reinterpret_cast(gv.PointerVal); + Range rng = *ptr; + return octave_value (rng); + } + else + assert (false && "Type not supported yet"); +} + +void +jit_typeinfo::reset_generic (void) +{ + scalar_out.clear (); + ov_out.clear (); + range_out.clear (); +} + +jit_type* +jit_typeinfo::new_type (const std::string& name, bool force_init, + jit_type *parent, llvm::Type *llvm_type) +{ + jit_type *ret = new jit_type (name, force_init, parent, llvm_type, next_id++); + id_to_type.push_back (ret); + return ret; +} + +// -------------------- jit_infer -------------------- +void +jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds) +{ + infer_simple_for (cmd, bounds); +} + +void +jit_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&) +{ + fail (); +} + +void +jit_infer::visit_argument_list (tree_argument_list&) +{ + fail (); +} + +void +jit_infer::visit_binary_expression (tree_binary_expression& be) +{ + if (is_lvalue) + fail (); + + if (be.op_type () >= octave_value::num_binary_ops) + fail (); + + tree_expression *lhs = be.lhs (); + lhs->accept (*this); + jit_type *tlhs = type_stack.back (); + type_stack.pop_back (); + + tree_expression *rhs = be.rhs (); + rhs->accept (*this); + jit_type *trhs = type_stack.back (); + + jit_type *result = tinfo->binary_op_result (be.op_type (), tlhs, trhs); + if (! result) + fail (); + + type_stack.push_back (result); +} + +void +jit_infer::visit_break_command (tree_break_command&) +{ + fail (); +} + +void +jit_infer::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +jit_infer::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +jit_infer::visit_global_command (tree_global_command&) +{ + fail (); +} + +void +jit_infer::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +jit_infer::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} + +void +jit_infer::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +jit_infer::visit_simple_for_command (tree_simple_for_command& cmd) +{ + tree_expression *control = cmd.control_expr (); + control->accept (*this); + + jit_type *control_t = type_stack.back (); + type_stack.pop_back (); + + // FIXME: We should improve type inference so we don't have to do this + // to generate nested for loop code + + // quick hack, check if the for loop bounds are const. If we + // run at least one, we don't have to merge types + bool atleast_once = false; + if (control->is_constant ()) + { + octave_value over = control->rvalue1 (); + if (over.is_range ()) + { + Range rng = over.range_value (); + atleast_once = rng.nelem () > 0; + } + } + + if (atleast_once) + infer_simple_for (cmd, control_t); + else + { + type_map fallthrough = types; + infer_simple_for (cmd, control_t); + merge (types, fallthrough); + } +} + +void +jit_infer::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +jit_infer::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +jit_infer::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +jit_infer::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +jit_infer::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +jit_infer::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +jit_infer::visit_identifier (tree_identifier& ti) +{ + symbol_table::symbol_record_ref record = ti.symbol (); + handle_identifier (record); +} + +void +jit_infer::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +jit_infer::visit_if_command (tree_if_command& cmd) +{ + if (is_lvalue) + fail (); + + tree_if_command_list *lst = cmd.cmd_list (); + assert (lst); + lst->accept (*this); +} + +void +jit_infer::visit_if_command_list (tree_if_command_list& lst) +{ + // determine the types on each branch of the if seperatly, then merge + type_map fallthrough = types, last; + bool first_time = true; + for (tree_if_command_list::iterator p = lst.begin (); p != lst.end(); ++p) + { + tree_if_clause *tic = *p; + + if (! first_time) + types = fallthrough; + + if (! tic->is_else_clause ()) + { + tree_expression *expr = tic->condition (); + expr->accept (*this); + } + + fallthrough = types; + + tree_statement_list *stmt_lst = tic->commands (); + assert (stmt_lst); + stmt_lst->accept (*this); + + if (first_time) + last = types; + else + merge (last, types); + } + + types = last; + + tree_if_clause *last_clause = lst.back (); + if (! last_clause->is_else_clause ()) + merge (types, fallthrough); +} + +void +jit_infer::visit_index_expression (tree_index_expression&) +{ + fail (); +} + +void +jit_infer::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +jit_infer::visit_cell (tree_cell&) +{ + fail (); +} + +void +jit_infer::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +jit_infer::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +jit_infer::visit_constant (tree_constant& tc) +{ + if (is_lvalue) + fail (); + + octave_value v = tc.rvalue1 (); + jit_type *type = tinfo->type_of (v); + if (! type) + fail (); + + type_stack.push_back (type); +} + +void +jit_infer::visit_fcn_handle (tree_fcn_handle&) +{ + fail (); +} + +void +jit_infer::visit_parameter_list (tree_parameter_list&) +{ + fail (); +} + +void +jit_infer::visit_postfix_expression (tree_postfix_expression&) +{ + fail (); +} + +void +jit_infer::visit_prefix_expression (tree_prefix_expression&) +{ + fail (); +} + +void +jit_infer::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +jit_infer::visit_return_list (tree_return_list&) +{ + fail (); +} + +void +jit_infer::visit_simple_assignment (tree_simple_assignment& tsa) +{ + if (is_lvalue) + fail (); + + // resolve rhs + is_lvalue = false; + tree_expression *rhs = tsa.right_hand_side (); + rhs->accept (*this); + + jit_type *trhs = type_stack.back (); + type_stack.pop_back (); + + // resolve lhs + is_lvalue = true; + rvalue_type = trhs; + tree_expression *lhs = tsa.left_hand_side (); + lhs->accept (*this); + + // we don't pop back here, as the resulting type should be the rhs type + // which is equal to the lhs type anways + jit_type *tlhs = type_stack.back (); + if (tlhs != trhs) + fail (); + + is_lvalue = false; + rvalue_type = 0; +} + +void +jit_infer::visit_statement (tree_statement& stmt) +{ + if (is_lvalue) + fail (); + + tree_command *cmd = stmt.command (); + tree_expression *expr = stmt.expression (); + + if (cmd) + cmd->accept (*this); + else + { + // ok, this check for ans appears three times as cp + bool do_bind_ans = false; + + if (expr->is_identifier ()) + { + tree_identifier *id = dynamic_cast (expr); + + do_bind_ans = (! id->is_variable ()); + } + else + do_bind_ans = (! expr->is_assignment_expression ()); + + expr->accept (*this); + + if (do_bind_ans) + { + is_lvalue = true; + rvalue_type = type_stack.back (); + type_stack.pop_back (); + + symbol_table::symbol_record_ref record (symbol_table::insert ("ans")); + handle_identifier (record); + + if (rvalue_type != type_stack.back ()) + fail (); + + is_lvalue = false; + rvalue_type = 0; + } + + type_stack.pop_back (); + } +} + +void +jit_infer::visit_statement_list (tree_statement_list& lst) +{ + tree_statement_list::iterator iter; + for (iter = lst.begin (); iter != lst.end (); ++iter) + { + tree_statement *stmt = *iter; + assert (stmt); // FIXME: jwe can this be null? + stmt->accept (*this); + } +} + +void +jit_infer::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +jit_infer::visit_switch_case_list (tree_switch_case_list&) +{ + fail (); +} + +void +jit_infer::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +jit_infer::visit_try_catch_command (tree_try_catch_command&) +{ + fail (); +} + +void +jit_infer::visit_unwind_protect_command (tree_unwind_protect_command&) +{ + fail (); +} + +void +jit_infer::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +jit_infer::visit_do_until_command (tree_do_until_command&) +{ + fail (); +} + +void +jit_infer::infer_simple_for (tree_simple_for_command& cmd, + jit_type *bounds) +{ + if (is_lvalue) + fail (); + + jit_type *iter = tinfo->get_simple_for_index_result (bounds); + if (! iter) + fail (); + + is_lvalue = true; + rvalue_type = iter; + tree_expression *lhs = cmd.left_hand_side (); + lhs->accept (*this); + if (type_stack.back () != iter) + fail (); + type_stack.pop_back (); + is_lvalue = false; + rvalue_type = 0; + + tree_statement_list *body = cmd.body (); + body->accept (*this); +} + +void +jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record) +{ + type_map::iterator iter = types.find (record); + if (iter == types.end ()) + { + jit_type *ty = tinfo->type_of (record->find ()); + bool argin = false; + if (is_lvalue) + { + if (! ty) + ty = rvalue_type; + } + else + { + if (! ty) + fail (); + argin = true; + } + + types[record] = type_entry (argin, ty); + type_stack.push_back (ty); + } + else + type_stack.push_back (iter->second.second); +} + +void +jit_infer::merge (type_map& dest, const type_map& src) +{ + if (dest.size () != src.size ()) + fail (); + + type_map::iterator dest_iter; + type_map::const_iterator src_iter; + for (dest_iter = dest.begin (), src_iter = src.begin (); + dest_iter != dest.end (); ++dest_iter, ++src_iter) + { + if (dest_iter->first.name () != src_iter->first.name () + || dest_iter->second.second != src_iter->second.second) + fail (); + + // require argin if one path requires argin + dest_iter->second.first = dest_iter->second.first + || src_iter->second.first; + } +} + +// -------------------- jit_generator -------------------- +jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod, + tree_simple_for_command& cmd, jit_type *bounds, + const type_map& infered_types) + : tinfo (ti), module (mod), is_lvalue (false) +{ + // create new vectors that include bounds + std::vector names (infered_types.size () + 1); + std::vector argin (infered_types.size () + 1); + std::vector types (infered_types.size () + 1); + names[0] = "#bounds"; + argin[0] = true; + types[0] = bounds; + size_t i; + type_map::const_iterator iter; + for (i = 1, iter = infered_types.begin (); iter != infered_types.end (); + ++i, ++iter) + { + names[i] = iter->first.name (); + argin[i] = iter->second.first; + types[i] = iter->second.second; + } + + initialize (names, argin, types); + + try + { + value var_bounds = variables["#bounds"]; + var_bounds.second = builder.CreateLoad (var_bounds.second); + emit_simple_for (cmd, var_bounds, true); + } + catch (const jit_fail_exception&) + { + function->eraseFromParent (); + function = 0; + return; + } + + finalize (names); +} + +void +jit_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&) +{ + fail (); +} + +void +jit_generator::visit_argument_list (tree_argument_list&) +{ + fail (); +} + +void +jit_generator::visit_binary_expression (tree_binary_expression& be) +{ + tree_expression *lhs = be.lhs (); + lhs->accept (*this); + value lhsv = value_stack.back (); + value_stack.pop_back (); + + tree_expression *rhs = be.rhs (); + rhs->accept (*this); + value rhsv = value_stack.back (); + value_stack.pop_back (); + + const jit_function::overload& ol + = tinfo->binary_op_overload (be.op_type (), lhsv.first, rhsv.first); + + if (! ol.function) + fail (); + + llvm::Value *result = builder.CreateCall2 (ol.function, lhsv.second, + rhsv.second); + push_value (ol.result, result); +} + +void +jit_generator::visit_break_command (tree_break_command&) +{ + fail (); +} + +void +jit_generator::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +jit_generator::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +jit_generator::visit_global_command (tree_global_command&) +{ + fail (); +} + +void +jit_generator::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +jit_generator::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} + +void +jit_generator::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +jit_generator::visit_simple_for_command (tree_simple_for_command& cmd) +{ + if (is_lvalue) + fail (); + + tree_expression *control = cmd.control_expr (); + assert (control); // FIXME: jwe, can this be null? + + control->accept (*this); + value over = value_stack.back (); + value_stack.pop_back (); + + emit_simple_for (cmd, over, false); +} + +void +jit_generator::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +jit_generator::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +jit_generator::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +jit_generator::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +jit_generator::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +jit_generator::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +jit_generator::visit_identifier (tree_identifier& ti) +{ + std::string name = ti.name (); + value variable = variables[name]; + if (is_lvalue) + { + value_stack.push_back (variable); + + const jit_function::overload& ol = tinfo->release (variable.first); + if (ol.function) + { + llvm::Value *load = builder.CreateLoad (variable.second, name); + builder.CreateCall (ol.function, load); + } + } + else + { + llvm::Value *load = builder.CreateLoad (variable.second, name); + push_value (variable.first, load); + + const jit_function::overload& ol = tinfo->grab (variable.first); + if (ol.function) + builder.CreateCall (ol.function, load); + } +} + +void +jit_generator::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +jit_generator::visit_if_command (tree_if_command& cmd) +{ + tree_if_command_list *lst = cmd.cmd_list (); + assert (lst); + lst->accept (*this); +} + +void +jit_generator::visit_if_command_list (tree_if_command_list& lst) +{ + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "if_tail", function); + std::vector clause_entry (lst.size ()); + tree_if_command_list::iterator p; + size_t i; + for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i) + { + tree_if_clause *tic = *p; + if (tic->is_else_clause ()) + clause_entry[i] = llvm::BasicBlock::Create (ctx, "else_body", function, + tail); + else + clause_entry[i] = llvm::BasicBlock::Create (ctx, "if_cond", function, + tail); + } + + builder.CreateBr (clause_entry[0]); + + for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i) + { + tree_if_clause *tic = *p; + llvm::BasicBlock *body; + if (tic->is_else_clause ()) + body = clause_entry[i]; + else + { + llvm::BasicBlock *cond = clause_entry[i]; + builder.SetInsertPoint (cond); + + tree_expression *expr = tic->condition (); + expr->accept (*this); + + // FIXME: Handle undefined case + value condv = value_stack.back (); + value_stack.pop_back (); + + const jit_function::overload& ol = tinfo->get_logically_true (condv.first); + if (! ol.function) + fail (); + + bool last = i + 1 == clause_entry.size (); + llvm::BasicBlock *next = last ? tail : clause_entry[i + 1]; + body = llvm::BasicBlock::Create (ctx, "if_body", function, tail); + + llvm::Value *is_true = builder.CreateCall (ol.function, condv.second); + builder.CreateCondBr (is_true, body, next); + } + + tree_statement_list *stmt_lst = tic->commands (); + builder.SetInsertPoint (body); + stmt_lst->accept (*this); + builder.CreateBr (tail); + } + + builder.SetInsertPoint (tail); +} + +void +jit_generator::visit_index_expression (tree_index_expression&) +{ + fail (); +} + +void +jit_generator::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +jit_generator::visit_cell (tree_cell&) +{ + fail (); +} + +void +jit_generator::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +jit_generator::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +jit_generator::visit_constant (tree_constant& tc) +{ + octave_value v = tc.rvalue1 (); + llvm::LLVMContext& ctx = llvm::getGlobalContext (); + if (v.is_real_scalar () && v.is_double_type ()) + { + double dv = v.double_value (); + llvm::Value *lv = llvm::ConstantFP::get (ctx, llvm::APFloat (dv)); + push_value (tinfo->get_scalar (), lv); + } + else if (v.is_range ()) + { + Range rng = v.range_value (); + llvm::Type *range = tinfo->get_range_llvm (); + llvm::Type *scalar = tinfo->get_scalar_llvm (); + llvm::Type *index = tinfo->get_index_llvm (); + + std::vector values (4); + values[0] = llvm::ConstantFP::get (scalar, rng.base ()); + values[1] = llvm::ConstantFP::get (scalar, rng.limit ()); + values[2] = llvm::ConstantFP::get (scalar, rng.inc ()); + values[3] = llvm::ConstantInt::get (index, rng.nelem ()); + + llvm::StructType *llvm_range = llvm::cast(range); + llvm::Value *lv = llvm::ConstantStruct::get (llvm_range, values); + push_value (tinfo->get_range (), lv); + } + else + fail (); +} + +void +jit_generator::visit_fcn_handle (tree_fcn_handle&) +{ + fail (); +} + +void +jit_generator::visit_parameter_list (tree_parameter_list&) +{ + fail (); +} + +void +jit_generator::visit_postfix_expression (tree_postfix_expression&) +{ + fail (); +} + +void +jit_generator::visit_prefix_expression (tree_prefix_expression&) +{ + fail (); +} + +void +jit_generator::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +jit_generator::visit_return_list (tree_return_list&) +{ + fail (); +} + +void +jit_generator::visit_simple_assignment (tree_simple_assignment& tsa) +{ + if (is_lvalue) + fail (); + + // resolve rhs + tree_expression *rhs = tsa.right_hand_side (); + rhs->accept (*this); + + value rhsv = value_stack.back (); + value_stack.pop_back (); + + // resolve lhs + is_lvalue = true; + tree_expression *lhs = tsa.left_hand_side (); + lhs->accept (*this); + is_lvalue = false; + + value lhsv = value_stack.back (); + value_stack.pop_back (); + + // do assign, then keep rhs as the result + builder.CreateStore (rhsv.second, lhsv.second); + + if (tsa.print_result ()) + emit_print (lhs->name (), rhsv); + + value_stack.push_back (rhsv); +} + +void +jit_generator::visit_statement (tree_statement& stmt) +{ + tree_command *cmd = stmt.command (); + tree_expression *expr = stmt.expression (); + + if (cmd) + cmd->accept (*this); + else + { + // stolen from tree_evaluator::visit_statement + bool do_bind_ans = false; + + if (expr->is_identifier ()) + { + tree_identifier *id = dynamic_cast (expr); + + do_bind_ans = (! id->is_variable ()); + } + else + do_bind_ans = (! expr->is_assignment_expression ()); + + expr->accept (*this); + + if (do_bind_ans) + { + value rhs = value_stack.back (); + value ans = variables["ans"]; + if (ans.first != rhs.first) + fail (); + + builder.CreateStore (rhs.second, ans.second); + + if (expr->print_result ()) + emit_print ("ans", rhs); + } + else if (expr->is_identifier () && expr->print_result ()) + { + // FIXME: ugly hack, we need to come up with a way to pass + // nargout to visit_identifier + emit_print (expr->name (), value_stack.back ()); + } + + + value_stack.pop_back (); + } +} + +void +jit_generator::visit_statement_list (tree_statement_list& lst) +{ + tree_statement_list::iterator iter; + for (iter = lst.begin (); iter != lst.end (); ++iter) + { + tree_statement *stmt = *iter; + assert (stmt); // FIXME: jwe can this be null? + stmt->accept (*this); + } +} + +void +jit_generator::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +jit_generator::visit_switch_case_list (tree_switch_case_list&) +{ + fail (); +} + +void +jit_generator::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +jit_generator::visit_try_catch_command (tree_try_catch_command&) +{ + fail (); +} + +void +jit_generator::visit_unwind_protect_command (tree_unwind_protect_command&) +{ + fail (); +} + +void +jit_generator::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +jit_generator::visit_do_until_command (tree_do_until_command&) +{ + fail (); +} + +void +jit_generator::emit_simple_for (tree_simple_for_command& cmd, value over, + bool atleast_once) +{ + if (is_lvalue) + fail (); + + jit_type *index = tinfo->get_index (); + llvm::Value *init_index = 0; + if (over.first == tinfo->get_range ()) + init_index = llvm::ConstantInt::get (index->to_llvm (), 0); + else + fail (); + + llvm::Value *llvm_index = builder.CreateAlloca (index->to_llvm (), 0, "index"); + builder.CreateStore (init_index, llvm_index); + + // FIXME: Support break + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "for_body", function); + llvm::BasicBlock *cond_check = llvm::BasicBlock::Create (ctx, "for_check", function); + llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "for_tail", function); + + // initialize the iter from the index + if (atleast_once) + builder.CreateBr (body); + else + builder.CreateBr (cond_check); + + builder.SetInsertPoint (body); + + is_lvalue = true; + tree_expression *lhs = cmd.left_hand_side (); + lhs->accept (*this); + is_lvalue = false; + + value lhsv = value_stack.back (); + value_stack.pop_back (); + + const jit_function::overload& index_ol = tinfo->get_simple_for_index (over.first); + llvm::Value *lindex = builder.CreateLoad (llvm_index); + llvm::Value *llvm_iter = builder.CreateCall2 (index_ol.function, over.second, lindex); + value iter(index_ol.result, llvm_iter); + builder.CreateStore (iter.second, lhsv.second); + + tree_statement_list *lst = cmd.body (); + lst->accept (*this); + + llvm::Value *one = llvm::ConstantInt::get (index->to_llvm (), 1); + lindex = builder.CreateLoad (llvm_index); + lindex = builder.CreateAdd (lindex, one); + builder.CreateStore (lindex, llvm_index); + builder.CreateBr (cond_check); + + builder.SetInsertPoint (cond_check); + lindex = builder.CreateLoad (llvm_index); + const jit_function::overload& check_ol = tinfo->get_simple_for_check (over.first); + llvm::Value *cond = builder.CreateCall2 (check_ol.function, over.second, lindex); + builder.CreateCondBr (cond, body, tail); + + builder.SetInsertPoint (tail); +} + +void +jit_generator::emit_print (const std::string& name, const value& v) +{ + const jit_function::overload& ol = tinfo->print_value (v.first); + if (! ol.function) + fail (); + + llvm::Value *str = builder.CreateGlobalStringPtr (name); + builder.CreateCall2 (ol.function, str, v.second); +} + +void +jit_generator::initialize (const std::vector& names, + const std::vector& argin, + const std::vector types) +{ + std::vector arg_types (names.size ()); + for (size_t i = 0; i < types.size (); ++i) + arg_types[i] = types[i]->to_llvm_arg (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); + + // create variables and copy initial values + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); + builder.SetInsertPoint (body); + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (size_t i = 0; i < names.size (); ++i, ++arg_iter) + { + llvm::Type *vartype = types[i]->to_llvm (); + const std::string& name = names[i]; + llvm::Value *var = builder.CreateAlloca (vartype, 0, name); + variables[name] = value (types[i], var); + + if (argin[i] || types[i]->force_init ()) + { + llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); + builder.CreateStore (loaded_arg, var); + } + } +} + +void +jit_generator::finalize (const std::vector& names) +{ + // copy computed values back into arguments + // we use names instead of looping through variables because order is + // important + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (size_t i = 0; i < names.size (); ++i, ++arg_iter) + { + llvm::Value *var = variables[names[i]].second; + llvm::Value *loaded_var = builder.CreateLoad (var); + builder.CreateStore (loaded_var, arg_iter); + } + builder.CreateRetVoid (); +} + +// -------------------- tree_jit -------------------- + +tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) +{ + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); +} + +tree_jit::~tree_jit (void) +{ + delete tinfo; +} + +bool +tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) +{ + if (! initialize ()) + return false; + + jit_type *bounds_t = tinfo->type_of (bounds); + jit_info *jinfo = cmd.get_info (bounds_t); + if (! jinfo) + { + jinfo = new jit_info (*this, cmd, bounds_t); + cmd.stash_info (bounds_t, jinfo); + } + + return jinfo->execute (bounds); +} + +bool +tree_jit::initialize (void) +{ + if (engine) + return true; + + // sometimes this fails pre main + engine = llvm::ExecutionEngine::createJIT (module); + + if (! engine) + return false; + + module_pass_manager = new llvm::PassManager (); + module_pass_manager->add (llvm::createAlwaysInlinerPass ()); + + pass_manager = new llvm::FunctionPassManager (module); + pass_manager->add (new llvm::TargetData(*engine->getTargetData ())); + pass_manager->add (llvm::createBasicAliasAnalysisPass ()); + pass_manager->add (llvm::createPromoteMemoryToRegisterPass ()); + pass_manager->add (llvm::createInstructionCombiningPass ()); + pass_manager->add (llvm::createReassociatePass ()); + pass_manager->add (llvm::createGVNPass ()); + pass_manager->add (llvm::createCFGSimplificationPass ()); + pass_manager->doInitialization (); + + tinfo = new jit_typeinfo (module, engine); + + return true; +} + + +void +tree_jit::optimize (llvm::Function *fn) +{ + module_pass_manager->run (*module); + pass_manager->run (*fn); +} + +// -------------------- jit_info -------------------- +jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd, + jit_type *bounds) : tinfo (tjit.get_typeinfo ()), + engine (tjit.get_engine ()), + bounds_t (bounds) +{ + jit_infer infer(tinfo); + + try + { + infer.infer (cmd, bounds); + } + catch (const jit_fail_exception&) + { + function = 0; + return; + } + + types = infer.get_types (); + + jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types); + function = gen.get_function (); + + if (function) + { + if (debug_print) + { + std::cout << "Compiled code:\n"; + std::cout << cmd.str_print_code () << std::endl; + + std::cout << "Before optimization:\n"; + + llvm::raw_os_ostream os (std::cout); + function->print (os); + } + llvm::verifyFunction (*function); + tjit.optimize (function); + + if (debug_print) + { + std::cout << "After optimization:\n"; + + llvm::raw_os_ostream os (std::cout); + function->print (os); + } + } +} + +bool +jit_info::execute (const octave_value& bounds) const +{ + if (! function) + return false; + + std::vector args (types.size () + 1); + tinfo->to_generic (bounds_t, args[0], bounds); + + size_t idx; + type_map::const_iterator iter; + for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) + { + if (iter->second.first) // argin? + { + octave_value ov = iter->first->varval (); + tinfo->to_generic (iter->second.second, args[idx], ov); + } + else + tinfo->to_generic (iter->second.second, args[idx]); + } + + engine->runFunction (function, args); + + for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) + { + octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]); + octave_value &ref = iter->first->varref (); + ref = result; + } + + tinfo->reset_generic (); + + return true; +} + +bool +jit_info::match () const +{ + for (type_map::const_iterator iter = types.begin (); iter != types.end (); + ++iter) + + { + if (iter->second.first) // argin? + { + jit_type *required_type = iter->second.second; + octave_value val = iter->first->varval (); + jit_type *current_type = tinfo->type_of (val); + + // FIXME: should be: ! required_type->is_parent (current_type) + if (required_type != current_type) + return false; + } + } + + return true; +} diff -r 757f729fd41d -r 0b0569667939 src/pt-jit.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/pt-jit.h Wed May 23 16:22:05 2012 -0400 @@ -0,0 +1,694 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_tree_jit_h) +#define octave_tree_jit_h 1 + +#include +#include +#include +#include +#include + +#include "Array.h" +#include "Range.h" +#include "pt-walk.h" +#include "symtab.h" + +// -------------------- Current status -------------------- +// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. +// However, there is no warning emitted on divide by 0. For example, +// a = 5; +// b = a * 5 + a; +// +// For other types all binary operations are compiled but not optimized. For +// example, +// a = [1 2 3] +// b = a + a; +// will compile to do_binary_op (a, a). +// +// for loops with ranges compile. For example, +// for i=1:1000 +// result = i + 1; +// endfor +// Will compile. Nested for loops with constant bounds are also supported. +// +// If statements/comparisons compile, but && and || do not. +// +// TODO: +// 1. Support iteration over matricies +// 2. Check error state +// 3. ... +// --------------------------------------------------------- + + +// we don't want to include llvm headers here, as they require __STDC_LIMIT_MACROS +// and __STDC_CONSTANT_MACROS be defined in the entire compilation unit +namespace llvm +{ + class Value; + class Module; + class FunctionPassManager; + class PassManager; + class ExecutionEngine; + class Function; + class BasicBlock; + class LLVMContext; + class Type; + class GenericValue; + class Twine; +} + +class octave_base_value; +class octave_value; +class tree; + +// jit_range is compatable with the llvm range structure +struct +jit_range +{ + jit_range (void) {} + + jit_range (const Range& from) : base (from.base ()), limit (from.limit ()), + inc (from.inc ()), nelem (from.nelem ()) + {} + + operator Range () const + { + return Range (base, limit, inc); + } + + double base; + double limit; + double inc; + octave_idx_type nelem; +}; + +// Used to keep track of estimated (infered) types during JIT. This is a +// hierarchical type system which includes both concrete and abstract types. +// +// Current, we only support any and scalar types. If we can't figure out what +// type a variable is, we assign it the any type. This allows us to generate +// code even for the case of poor type inference. +class +jit_type +{ +public: + jit_type (const std::string& n, bool fi, jit_type *mparent, llvm::Type *lt, + int tid) : + mname (n), finit (fi), p (mparent), llvm_type (lt), id (tid) + {} + + // a user readable type name + const std::string& name (void) const { return mname; } + + // do we need to initialize variables of this type, even if they are not + // input arguments? + bool force_init (void) const { return finit; } + + // a unique id for the type + int type_id (void) const { return id; } + + // An abstract base type, may be null + jit_type *parent (void) const { return p; } + + // convert to an llvm type + llvm::Type *to_llvm (void) const { return llvm_type; } + + // how this type gets passed as a function argument + llvm::Type *to_llvm_arg (void) const; +private: + std::string mname; + bool finit; + jit_type *p; + llvm::Type *llvm_type; + int id; +}; + +// Keeps track of overloads for a builtin function. Used for both type inference +// and code generation. +class +jit_function +{ +public: + struct overload + { + overload (void) : function (0), can_error (true), result (0) {} + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) : + function (f), can_error (e), result (r), arguments (1) + { + arguments[0] = arg0; + } + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) : function (f), can_error (e), result (r), + arguments (2) + { + arguments[0] = arg0; + arguments[1] = arg1; + } + + llvm::Function *function; + bool can_error; + jit_type *result; + std::vector arguments; + }; + + void add_overload (const overload& func) + { + add_overload (func, func.arguments); + } + + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) + { + overload ol (f, e, r, arg0); + add_overload (ol); + } + + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) + { + overload ol (f, e, r, arg0, arg1); + add_overload (ol); + } + + void add_overload (const overload& func, + const std::vector& args); + + const overload& get_overload (const std::vector& types) const; + + const overload& get_overload (jit_type *arg0) const + { + std::vector types (1); + types[0] = arg0; + return get_overload (types); + } + + const overload& get_overload (jit_type *arg0, jit_type *arg1) const + { + std::vector types (2); + types[0] = arg0; + types[1] = arg1; + return get_overload (types); + } + + jit_type *get_result (const std::vector& types) const + { + const overload& temp = get_overload (types); + return temp.result; + } + + jit_type *get_result (jit_type *arg0, jit_type *arg1) const + { + const overload& temp = get_overload (arg0, arg1); + return temp.result; + } +private: + Array to_idx (const std::vector& types) const; + + std::vector > overloads; +}; + +// Get information and manipulate jit types. +class +jit_typeinfo +{ +public: + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); + + jit_type *get_any (void) const { return any; } + + jit_type *get_scalar (void) const { return scalar; } + + llvm::Type *get_scalar_llvm (void) const { return scalar->to_llvm (); } + + jit_type *get_range (void) const { return range; } + + llvm::Type *get_range_llvm (void) const { return range->to_llvm (); } + + jit_type *get_bool (void) const { return boolean; } + + jit_type *get_index (void) const { return index; } + + llvm::Type *get_index_llvm (void) const { return index->to_llvm (); } + + jit_type *type_of (const octave_value& ov) const; + + const jit_function& binary_op (int op) const; + + const jit_function::overload& binary_op_overload (int op, jit_type *lhs, + jit_type *rhs) const + { + const jit_function& jf = binary_op (op); + return jf.get_overload (lhs, rhs); + } + + jit_type *binary_op_result (int op, jit_type *lhs, jit_type *rhs) const + { + const jit_function::overload& ol = binary_op_overload (op, lhs, rhs); + return ol.result; + } + + const jit_function::overload& grab (jit_type *ty) const + { + return grab_fn.get_overload (ty); + } + + const jit_function::overload& release (jit_type *ty) const + { + return release_fn.get_overload (ty); + } + + const jit_function::overload& print_value (jit_type *to_print) const; + + const jit_function::overload& get_simple_for_check (jit_type *bounds) const + { + return simple_for_check.get_overload (bounds, index); + } + + const jit_function::overload& get_simple_for_index (jit_type *bounds) const + { + return simple_for_index.get_overload (bounds, index); + } + + jit_type *get_simple_for_index_result (jit_type *bounds) const + { + const jit_function::overload& ol = get_simple_for_index (bounds); + return ol.result; + } + + const jit_function::overload& get_logically_true (jit_type *conv) const + { + return logically_true.get_overload (conv); + } + + void to_generic (jit_type *type, llvm::GenericValue& gv); + + void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov); + + octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv); + + void reset_generic (void); +private: + jit_type *new_type (const std::string& name, bool force_init, + jit_type *parent, llvm::Type *llvm_type); + + void add_print (jit_type *ty, void *call); + + void add_binary_op (jit_type *ty, int op, int llvm_op); + + void add_binary_icmp (jit_type *ty, int op, int llvm_op); + + void add_binary_fcmp (jit_type *ty, int op, int llvm_op); + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0) + { + std::vector args (1, arg0); + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + jit_type *arg0) + { + return create_function (name, ret->to_llvm (), arg0->to_llvm ()); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0, llvm::Type *arg1) + { + std::vector args (2); + args[0] = arg0; + args[1] = arg1; + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + jit_type *arg0, jit_type *arg1) + { + return create_function (name, ret->to_llvm (), arg0->to_llvm (), + arg1->to_llvm ()); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0, llvm::Type *arg1, + llvm::Type *arg2) + { + std::vector args (3); + args[0] = arg0; + args[1] = arg1; + args[2] = arg2; + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + jit_type *arg0, jit_type *arg1, + jit_type *arg2) + { + return create_function (name, ret->to_llvm (), arg0->to_llvm (), + arg1->to_llvm (), arg2->to_llvm ()); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + const std::vector& args); + + llvm::Module *module; + llvm::ExecutionEngine *engine; + int next_id; + + llvm::Type *ov_t; + + std::vector id_to_type; + jit_type *any; + jit_type *scalar; + jit_type *range; + jit_type *boolean; + jit_type *index; + + std::vector binary_ops; + jit_function grab_fn; + jit_function release_fn; + jit_function print_fn; + jit_function simple_for_check; + jit_function simple_for_incr; + jit_function simple_for_index; + jit_function logically_true; + + std::list scalar_out; + std::list ov_out; + std::list range_out; +}; + +class +jit_infer : public tree_walker +{ +public: + // pair + typedef std::pair type_entry; + typedef std::map type_map; + + jit_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), + rvalue_type (0) + {} + + const type_map& get_types () const { return types; } + + void infer (tree_simple_for_command& cmd, jit_type *bounds); + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); +private: + void infer_simple_for (tree_simple_for_command& cmd, + jit_type *bounds); + + void handle_identifier (const symbol_table::symbol_record_ref& record); + + void merge (type_map& dest, const type_map& src); + + jit_typeinfo *tinfo; + + bool is_lvalue; + jit_type *rvalue_type; + + type_map types; + + std::vector type_stack; +}; + +class +jit_generator : public tree_walker +{ +public: + typedef jit_infer::type_map type_map; + + jit_generator (jit_typeinfo *ti, llvm::Module *mod, tree_simple_for_command &cmd, + jit_type *bounds, const type_map& infered_types); + + llvm::Function *get_function () const { return function; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); +private: + typedef std::pair value; + + void emit_simple_for (tree_simple_for_command& cmd, value over, + bool atleast_once); + + void emit_print (const std::string& name, const value& v); + + void push_value (jit_type *type, llvm::Value *v) + { + value_stack.push_back (value (type, v)); + } + + void initialize (const std::vector& names, + const std::vector& argin, + const std::vector types); + + void finalize (const std::vector& names); + + jit_typeinfo *tinfo; + llvm::Module *module; + llvm::Function *function; + + bool is_lvalue; + std::map variables; + std::vector value_stack; +}; + +class +tree_jit +{ +public: + tree_jit (void); + + ~tree_jit (void); + + bool execute (tree_simple_for_command& cmd, const octave_value& bounds); + + jit_typeinfo *get_typeinfo (void) const { return tinfo; } + + llvm::ExecutionEngine *get_engine (void) const { return engine; } + + llvm::Module *get_module (void) const { return module; } + + void optimize (llvm::Function *fn); + private: + bool initialize (void); + + llvm::LLVMContext &context; + llvm::Module *module; + llvm::PassManager *module_pass_manager; + llvm::FunctionPassManager *pass_manager; + llvm::ExecutionEngine *engine; + + jit_typeinfo *tinfo; +}; + +class +jit_info +{ +public: + typedef jit_infer::type_map type_map; + + jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds); + + bool execute (const octave_value& bounds) const; + + bool match (void) const; +private: + jit_typeinfo *tinfo; + llvm::ExecutionEngine *engine; + type_map types; + llvm::Function *function; + jit_type *bounds_t; +}; + +#endif diff -r 757f729fd41d -r 0b0569667939 src/pt-loop.cc --- a/src/pt-loop.cc Wed May 23 13:36:24 2012 -0400 +++ b/src/pt-loop.cc Wed May 23 16:22:05 2012 -0400 @@ -35,6 +35,7 @@ #include "pt-bp.h" #include "pt-cmd.h" #include "pt-exp.h" +#include "pt-jit.h" #include "pt-jump.h" #include "pt-loop.h" #include "pt-stmt.h" @@ -97,6 +98,10 @@ delete list; delete lead_comm; delete trail_comm; + + for (compiled_map::iterator iter = compiled.begin (); iter != compiled.end (); + ++iter) + delete iter->second; } tree_command * diff -r 757f729fd41d -r 0b0569667939 src/pt-loop.h --- a/src/pt-loop.h Wed May 23 13:36:24 2012 -0400 +++ b/src/pt-loop.h Wed May 23 16:22:05 2012 -0400 @@ -36,6 +36,9 @@ #include "pt-cmd.h" #include "symtab.h" +class jit_info; +class jit_type; + // While. class @@ -180,7 +183,20 @@ void accept (tree_walker& tw); + // some functions use by tree_jit + jit_info *get_info (jit_type *type) const + { + compiled_map::const_iterator iter = compiled.find (type); + return iter != compiled.end () ? iter->second : 0; + } + + void stash_info (jit_type *type, jit_info *jinfo) + { + compiled[type] = jinfo; + } + private: + typedef std::map compiled_map; // TRUE means operate in parallel (subject to the value of the // maxproc expression). @@ -205,6 +221,9 @@ // Comment preceding ENDFOR token. octave_comment_list *trail_comm; + // a map from iterator types -> compiled functions + compiled_map compiled; + // No copying! tree_simple_for_command (const tree_simple_for_command&); diff -r 757f729fd41d -r 0b0569667939 src/pt-stmt.h --- a/src/pt-stmt.h Wed May 23 13:36:24 2012 -0400 +++ b/src/pt-stmt.h Wed May 23 16:22:05 2012 -0400 @@ -35,12 +35,13 @@ #include "base-list.h" #include "comment-list.h" #include "symtab.h" +#include "pt.h" // A statement is either a command to execute or an expression to // evaluate. class -tree_statement +tree_statement : public tree { public: diff -r 757f729fd41d -r 0b0569667939 src/symtab.h --- a/src/symtab.h Wed May 23 13:36:24 2012 -0400 +++ b/src/symtab.h Wed May 23 16:22:05 2012 -0400 @@ -484,7 +484,7 @@ return symbol_record (rep->dup (new_scope)); } - std::string name (void) const { return rep->name; } + const std::string& name (void) const { return rep->name; } octave_value find (const octave_value_list& args = octave_value_list ()) const; @@ -581,6 +581,66 @@ symbol_record (symbol_record_rep *new_rep) : rep (new_rep) { } }; + // Always access a symbol from the current scope. + // Useful for scripts, as they may be executed in more than one scope. + class + symbol_record_ref + { + public: + symbol_record_ref (void) : scope (-1) {} + + symbol_record_ref (symbol_record record, + scope_id curr_scope = symbol_table::current_scope ()) + : scope (curr_scope), sym (record) + {} + + symbol_record_ref& operator = (const symbol_record_ref& ref) + { + scope = ref.scope; + sym = ref.sym; + return *this; + } + + // The name is the same regardless of scope. + const std::string& name (void) const { return sym.name (); } + + symbol_record *operator-> (void) + { + update (); + return &sym; + } + + symbol_record *operator-> (void) const + { + update (); + return &sym; + } + + // can be used to place symbol_record_ref in maps, we don't overload < as + // it doesn't make any sense for symbol_record_ref + struct comparator + { + bool operator ()(const symbol_record_ref& lhs, + const symbol_record_ref& rhs) const + { + return lhs.name () < rhs.name (); + } + }; + private: + void update (void) const + { + scope_id curr_scope = symbol_table::current_scope (); + if (scope != curr_scope || ! sym.is_valid ()) + { + scope = curr_scope; + sym = symbol_table::insert (sym.name ()); + } + } + + mutable scope_id scope; + mutable symbol_record sym; + }; + class fcn_info { diff -r 757f729fd41d -r 0b0569667939 src/toplev.cc --- a/src/toplev.cc Wed May 23 13:36:24 2012 -0400 +++ b/src/toplev.cc Wed May 23 16:22:05 2012 -0400 @@ -1325,6 +1325,9 @@ { false, "MAGICK_CPPFLAGS", OCTAVE_CONF_MAGICK_CPPFLAGS }, { false, "MAGICK_LDFLAGS", OCTAVE_CONF_MAGICK_LDFLAGS }, { false, "MAGICK_LIBS", OCTAVE_CONF_MAGICK_LIBS }, + { false, "LLVM_CPPFLAGS", OCTAVE_CONF_LLVM_CPPFLAGS }, + { false, "LLVM_LDFLAGS", OCTAVE_CONF_LLVM_LDFLAGS }, + { false, "LLVM_LIBS", OCTAVE_CONF_LLVM_LIBS }, { false, "MKOCTFILE_DL_LDFLAGS", OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS }, { false, "OCTAVE_LINK_DEPS", OCTAVE_CONF_OCTAVE_LINK_DEPS }, { false, "OCTAVE_LINK_OPTS", OCTAVE_CONF_OCTAVE_LINK_OPTS },