From a608da2103e20a0c0a39bebbd50692865d798ee7 Mon Sep 17 00:00:00 2001 From: oli Date: Thu, 8 Apr 2021 15:04:04 +0000 Subject: [PATCH 1/4] towards compiling gnur bc to pir --- rir/src/api.cpp | 33 +- rir/src/compiler/backend.cpp | 5 +- rir/src/compiler/compiler.cpp | 24 +- rir/src/compiler/compiler.h | 7 +- rir/src/compiler/gnur2pir/gnur2pir.cpp | 468 +++++++++++++++++++++ rir/src/compiler/gnur2pir/gnur2pir.h | 24 ++ rir/src/compiler/log/stream_logger.cpp | 3 +- rir/src/compiler/native/pir_jit_llvm.h | 8 +- rir/src/compiler/opt/eager_calls.cpp | 9 +- rir/src/compiler/opt/force_dominance.cpp | 3 +- rir/src/compiler/opt/inline.cpp | 44 +- rir/src/compiler/opt/inline_force_prom.cpp | 50 ++- rir/src/compiler/opt/promise_splitter.cpp | 3 +- rir/src/compiler/pir/builder.cpp | 18 +- rir/src/compiler/pir/closure.cpp | 22 +- rir/src/compiler/pir/closure.h | 9 +- rir/src/compiler/pir/closure_version.cpp | 12 +- rir/src/compiler/pir/closure_version.h | 19 +- rir/src/compiler/pir/code.h | 31 +- rir/src/compiler/pir/instruction.cpp | 6 +- rir/src/compiler/pir/instruction.h | 3 +- rir/src/compiler/pir/module.cpp | 12 + rir/src/compiler/pir/module.h | 4 +- rir/src/compiler/pir/promise.cpp | 15 +- rir/src/compiler/pir/promise.h | 17 +- rir/src/compiler/pir/type.h | 6 +- rir/src/compiler/rir2pir/rir2pir.cpp | 49 ++- rir/src/compiler/rir2pir/rir2pir.h | 13 +- rir/src/compiler/test/PirTests.cpp | 3 +- rir/src/compiler/util/bb_transform.cpp | 2 +- rir/src/runtime/Code.cpp | 2 +- rir/src/runtime/Code.h | 2 +- rir/src/utils/FormalArgs.h | 5 + 33 files changed, 737 insertions(+), 194 deletions(-) create mode 100644 rir/src/compiler/gnur2pir/gnur2pir.cpp create mode 100644 rir/src/compiler/gnur2pir/gnur2pir.h diff --git a/rir/src/api.cpp b/rir/src/api.cpp index 8e6d0ee75..b4cd78971 100644 --- a/rir/src/api.cpp +++ b/rir/src/api.cpp @@ -7,6 +7,7 @@ #include "R/Serialize.h" #include "compiler/backend.h" #include "compiler/compiler.h" +#include "compiler/gnur2pir/gnur2pir.h" #include "compiler/log/debug.h" #include "compiler/parameter.h" #include "compiler/test/PirCheck.h" @@ -286,15 +287,7 @@ REXPORT SEXP pirSetDebugFlags(SEXP debugFlags) { SEXP pirCompile(SEXP what, const Context& assumptions, const std::string& name, const pir::DebugOptions& debug) { - if (!isValidClosureSEXP(what)) { - Rf_error("not a compiled closure"); - } - if (!DispatchTable::check(BODY(what))) { - Rf_error("Cannot optimize compiled expression, only closure"); - } - PROTECT(what); - bool dryRun = debug.includes(pir::DebugFlag::DryRun); // compile to pir pir::Module* m = new pir::Module; @@ -365,6 +358,30 @@ REXPORT SEXP pirCompileWrapper(SEXP what, SEXP name, SEXP debugFlags, return pirCompile(what, rir::pir::Compiler::defaultContext, n, opts); } +REXPORT SEXP gnur2pir(SEXP what) { + // compile to pir + pir::StreamLogger logger(PirDebug); + pir::Module m; + pir::Gnur2Pir g2p(m); + pir::ClosureVersion* c = g2p.compile(what, ""); + if (!c) + return R_NilValue; + + c->printCode(std::cout, true, false); + pir::Compiler cmp(&m, logger); + pir::Backend backend(logger, ""); + auto f = backend.getOrCompile(c); + + // TODO the rest of the system really wants the dispatch table to only + // contain rir::Functions... So we need to compile the baseline anyway to be + // able to call the opt version... + what = rirCompile(what, nullptr); + auto dt = DispatchTable::unpack(BODY(what)); + dt->insert(f); + + return what; +} + REXPORT SEXP pirTests() { PirTests::run(); return R_NilValue; diff --git a/rir/src/compiler/backend.cpp b/rir/src/compiler/backend.cpp index 4596b1a41..217f531e0 100644 --- a/rir/src/compiler/backend.cpp +++ b/rir/src/compiler/backend.cpp @@ -348,7 +348,7 @@ rir::Function* Backend::doCompile(ClosureVersion* cls, approximateRefcount(cls, c, refcount, log); std::unordered_set needsLdVarForUpdate; approximateNeedsLdVarForUpdate(c, needsLdVarForUpdate); - auto res = done[c] = rir::Code::New(c->rirSrc()->src); + auto res = done[c] = rir::Code::New(c->expression()); // Can we do better? preserve(res->container()); jit.compile(res, c, promMap.at(c), refcount, needsLdVarForUpdate, log); @@ -375,7 +375,8 @@ rir::Function* Backend::doCompile(ClosureVersion* cls, log.finalPIR(cls); function.finalize(body, signature, cls->context()); - function.function()->inheritFlags(cls->owner()->rirFunction()); + cls->owner()->rirFunction( + [&](rir::Function* f) { function.function()->inheritFlags(f); }); return function.function(); } diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 79812cbf4..60fd95db3 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -52,8 +52,8 @@ void Compiler::compileClosure(SEXP closure, const std::string& name, auto pirClosure = module->getOrDeclareRirClosure(closureName, closure, fun, tbl->userDefinedContext()); Context context(assumptions); - compileClosure(pirClosure, tbl->dispatch(assumptions), context, root, - success, fail, outerFeedback); + compileClosure(pirClosure, tbl->baseline(), tbl->dispatch(assumptions), + context, root, success, fail, outerFeedback); } void Compiler::compileFunction(rir::DispatchTable* src, const std::string& name, @@ -68,13 +68,13 @@ void Compiler::compileFunction(rir::DispatchTable* src, const std::string& name, Context context(assumptions); auto closure = module->getOrDeclareRirFunction( name, srcFunction, formals, srcRef, src->userDefinedContext()); - compileClosure(closure, src->dispatch(assumptions), context, false, success, - fail, outerFeedback); + compileClosure(closure, srcFunction, src->dispatch(assumptions), context, + false, success, fail, outerFeedback); } -void Compiler::compileClosure(Closure* closure, rir::Function* optFunction, - const Context& ctx, bool root, MaybeCls success, - Maybe fail, +void Compiler::compileClosure(Closure* closure, rir::Function* srcCode, + rir::Function* optFunction, const Context& ctx, + bool root, MaybeCls success, Maybe fail, std::list outerFeedback) { if (!ctx.includes(minimalContext)) { @@ -99,8 +99,10 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction, return fail(); } - if (closure->rirFunction()->body()->codeSize > Parameter::MAX_INPUT_SIZE) { - closure->rirFunction()->flags.set(Function::NotOptimizable); + auto sz = closure->bodySize(); + if (sz > Parameter::MAX_INPUT_SIZE) { + closure->rirFunction( + [&](rir::Function* f) { f->flags.set(Function::NotOptimizable); }); logger.warn("skipping huge function"); return fail(); } @@ -125,7 +127,7 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction, auto arg = closure->formals().defaultArgs()[idx]; assert(rir::Code::check(arg) && "Default arg not compiled"); auto code = rir::Code::unpack(arg); - auto res = rir2pir.tryCreateArg(code, builder, false); + auto res = rir2pir.tryCreateArg(code, builder, false, -1); if (!res) { failedToCompileDefaultArgs = true; return; @@ -186,7 +188,7 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction, return fail(); } - if (rir2pir.tryCompile(builder)) { + if (rir2pir.tryCompile(srcCode->body(), builder)) { log.compilationEarlyPir(version); #ifdef FULLVERIFIER Verify::apply(version, "Error after initial translation", true); diff --git a/rir/src/compiler/compiler.h b/rir/src/compiler/compiler.h index 28983adf4..c4c4c1c2d 100644 --- a/rir/src/compiler/compiler.h +++ b/rir/src/compiler/compiler.h @@ -46,9 +46,10 @@ class Compiler { Module* module; StreamLogger& logger; - void compileClosure(Closure* closure, rir::Function* optFunction, - const Context& ctx, bool root, MaybeCls success, - Maybe fail, std::list outerFeedback); + void compileClosure(Closure* closure, rir::Function* srcCode, + rir::Function* optFunction, const Context& ctx, + bool root, MaybeCls success, Maybe fail, + std::list outerFeedback); Preserve preserve_; }; diff --git a/rir/src/compiler/gnur2pir/gnur2pir.cpp b/rir/src/compiler/gnur2pir/gnur2pir.cpp new file mode 100644 index 000000000..1a5029543 --- /dev/null +++ b/rir/src/compiler/gnur2pir/gnur2pir.cpp @@ -0,0 +1,468 @@ +#include "gnur2pir.h" +#include "../rir2pir/insert_cast.h" +#include "R/BuiltinIds.h" +#include "R/Funtab.h" +#include "R/RList.h" +#include "R/Symbols.h" +#include "compiler/analysis/cfg.h" +#include "compiler/analysis/query.h" +#include "compiler/analysis/verifier.h" +#include "compiler/opt/pass_definitions.h" +#include "compiler/pir/builder.h" +#include "compiler/pir/pir_impl.h" +#include "compiler/util/arg_match.h" +#include "compiler/util/visitor.h" +#include "ir/BC.h" +#include "ir/Compiler.h" +#include "runtime/ArglistOrder.h" +#include "simple_instruction_list.h" +#include "utils/FormalArgs.h" + +#include +#include +#include + +namespace rir { +namespace pir { + +typedef std::pair ReturnSite; + +struct State { + bool seen = false; + BB* entryBB = nullptr; + size_t entryPC = 0; + + State() {} + State(State&&) = default; + State(const State&) = delete; + State(const State& other, bool seen, BB* entryBB, size_t entryPC) + : seen(seen), entryBB(entryBB), entryPC(entryPC), stack(other.stack){}; + + void operator=(const State&) = delete; + State& operator=(State&&) = default; + + void mergeIn(const State& incom, BB* incomBB); + void createMergepoint(Builder&); + + void clear() { + stack.clear(); + entryBB = nullptr; + entryPC = 0; + } + + RirStack stack; +}; + +void State::createMergepoint(Builder& insert) { + BB* oldBB = insert.getCurrentBB(); + insert.createNextBB(); + for (size_t i = 0; i < stack.size(); ++i) { + auto v = stack.at(i); + auto p = insert(new Phi); + p->addInput(oldBB, v); + stack.at(i) = p; + } +} + +void State::mergeIn(const State& incom, BB* incomBB) { + assert(stack.size() == incom.stack.size()); + + for (size_t i = 0; i < stack.size(); ++i) { + Phi* p = Phi::Cast(stack.at(i)); + assert(p); + Value* in = incom.stack.at(i); + if (in != Tombstone::unreachable()) + p->addInput(incomBB, in); + } + incomBB->setNext(entryBB); +} + +#define GNUR_BC(V) \ + V(BCMISMATCH_OP, 0) \ + V(RETURN_OP, 0) \ + V(GOTO_OP, 1) \ + V(BRIFNOT_OP, 2) \ + V(POP_OP, 0) \ + V(DUP_OP, 0) \ + V(PRINTVALUE_OP, 0) \ + V(STARTLOOPCNTXT_OP, 2) \ + V(ENDLOOPCNTXT_OP, 1) \ + V(DOLOOPNEXT_OP, 0) \ + V(DOLOOPBREAK_OP, 0) \ + V(STARTFOR_OP, 3) \ + V(STEPFOR_OP, 1) \ + V(ENDFOR_OP, 0) \ + V(SETLOOPVAL_OP, 0) \ + V(INVISIBLE_OP, 0) \ + V(LDCONST_OP, 1) \ + V(LDNULL_OP, 0) \ + V(LDTRUE_OP, 0) \ + V(LDFALSE_OP, 0) \ + V(GETVAR_OP, 1) \ + V(DDVAL_OP, 1) \ + V(SETVAR_OP, 1) \ + V(GETFUN_OP, 1) \ + V(GETGLOBFUN_OP, 1) \ + V(GETSYMFUN_OP, 1) \ + V(GETBUILTIN_OP, 1) \ + V(GETINTLBUILTIN_OP, 1) \ + V(CHECKFUN_OP, 0) \ + V(MAKEPROM_OP, 1) \ + V(DOMISSING_OP, 0) \ + V(SETTAG_OP, 1) \ + V(DODOTS_OP, 0) \ + V(PUSHARG_OP, 0) \ + V(PUSHCONSTARG_OP, 1) \ + V(PUSHNULLARG_OP, 0) \ + V(PUSHTRUEARG_OP, 0) \ + V(PUSHFALSEARG_OP, 0) \ + V(CALL_OP, 1) \ + V(CALLBUILTIN_OP, 1) \ + V(CALLSPECIAL_OP, 1) \ + V(MAKECLOSURE_OP, 1) \ + V(UMINUS_OP, 1) \ + V(UPLUS_OP, 1) \ + V(ADD_OP, 1) \ + V(SUB_OP, 1) \ + V(MUL_OP, 1) \ + V(DIV_OP, 1) \ + V(EXPT_OP, 1) \ + V(SQRT_OP, 1) \ + V(EXP_OP, 1) \ + V(EQ_OP, 1) \ + V(NE_OP, 1) \ + V(LT_OP, 1) \ + V(LE_OP, 1) \ + V(GE_OP, 1) \ + V(GT_OP, 1) \ + V(AND_OP, 1) \ + V(OR_OP, 1) \ + V(NOT_OP, 1) \ + V(DOTSERR_OP, 0) \ + V(STARTASSIGN_OP, 1) \ + V(ENDASSIGN_OP, 1) \ + V(STARTSUBSET_OP, 2) \ + V(DFLTSUBSET_OP, 0) \ + V(STARTSUBASSIGN_OP, 2) \ + V(DFLTSUBASSIGN_OP, 0) \ + V(STARTC_OP, 2) \ + V(DFLTC_OP, 0) \ + V(STARTSUBSET2_OP, 2) \ + V(DFLTSUBSET2_OP, 0) \ + V(STARTSUBASSIGN2_OP, 2) \ + V(DFLTSUBASSIGN2_OP, 0) \ + V(DOLLAR_OP, 2) \ + V(DOLLARGETS_OP, 2) \ + V(ISNULL_OP, 0) \ + V(ISLOGICAL_OP, 0) \ + V(ISINTEGER_OP, 0) \ + V(ISDOUBLE_OP, 0) \ + V(ISCOMPLEX_OP, 0) \ + V(ISCHARACTER_OP, 0) \ + V(ISSYMBOL_OP, 0) \ + V(ISOBJECT_OP, 0) \ + V(ISNUMERIC_OP, 0) \ + V(VECSUBSET_OP, 1) \ + V(MATSUBSET_OP, 1) \ + V(VECSUBASSIGN_OP, 1) \ + V(MATSUBASSIGN_OP, 1) \ + V(AND1ST_OP, 2) \ + V(AND2ND_OP, 1) \ + V(OR1ST_OP, 2) \ + V(OR2ND_OP, 1) \ + V(GETVAR_MISSOK_OP, 1) \ + V(DDVAL_MISSOK_OP, 1) \ + V(VISIBLE_OP, 0) \ + V(SETVAR2_OP, 1) \ + V(STARTASSIGN2_OP, 1) \ + V(ENDASSIGN2_OP, 1) \ + V(SETTER_CALL_OP, 2) \ + V(GETTER_CALL_OP, 1) \ + V(SWAP_OP, 0) \ + V(DUP2ND_OP, 0) \ + V(SWITCH_OP, 4) \ + V(RETURNJMP_OP, 0) \ + V(STARTSUBSET_N_OP, 2) \ + V(STARTSUBASSIGN_N_OP, 2) \ + V(VECSUBSET2_OP, 1) \ + V(MATSUBSET2_OP, 1) \ + V(VECSUBASSIGN2_OP, 1) \ + V(MATSUBASSIGN2_OP, 1) \ + V(STARTSUBSET2_N_OP, 2) \ + V(STARTSUBASSIGN2_N_OP, 2) \ + V(SUBSET_N_OP, 2) \ + V(SUBSET2_N_OP, 2) \ + V(SUBASSIGN_N_OP, 2) \ + V(SUBASSIGN2_N_OP, 2) \ + V(LOG_OP, 1) \ + V(LOGBASE_OP, 1) \ + V(MATH1_OP, 2) \ + V(DOTCALL_OP, 2) \ + V(COLON_OP, 1) \ + V(SEQALONG_OP, 1) \ + V(SEQLEN_OP, 1) \ + V(BASEGUARD_OP, 2) \ + V(INCLNK_OP, 0) \ + V(DECLNK_OP, 0) \ + V(DECLNK_N_OP, 1) + +class RBC { + public: + enum Id { +#define V(BC, _) BC, + GNUR_BC(V) +#undef V + }; + + const Id id; + + bool valid() const { + switch (id) { +#define V(BC, I) case BC: + GNUR_BC(V) +#undef V + return true; + default: {} + } + return false; + } + + size_t imm() const { + switch (id) { +#define V(BC, I) \ + case BC: \ + return I; + GNUR_BC(V) +#undef V + } + assert(false); + return 0; + } + int imm(size_t pos) const { + assert(pos < parsed); + return args[pos]; + } + + explicit RBC(const int* pc) : id((Id)pc[0]) { + for (; parsed < imm(); ++parsed) + args[parsed] = pc[parsed + 1]; + assert(valid()); + }; + + bool operator==(int bc) const { return id == bc; } + + bool jumps() const { + return id == RBC::BRIFNOT_OP || id == RBC::GOTO_OP || + id == RBC::STARTFOR_OP || id == RBC::STEPFOR_OP; + } + + size_t jump() const { + if (id == RBC::BRIFNOT_OP) + return imm(1); + if (id == RBC::GOTO_OP) + return imm(0); + if (id == RBC::STARTFOR_OP) + return imm(2); + if (id == RBC::STEPFOR_OP) + return imm(0); + assert(false); + return 0; + } + + bool falls() const { + return id != RBC::RETURN_OP && id != RBC::GOTO_OP && + id != RBC::STARTFOR_OP; + } + + friend std::ostream& operator<<(std::ostream& out, Id id) { + switch (id) { +#define V(BC, _) \ + case BC: \ + out << #BC; \ + break; + GNUR_BC(V) +#undef V + } + return out; + } + + friend std::ostream& operator<<(std::ostream& out, RBC bc) { + out << bc.id; + if (bc.imm() > 0) { + out << "("; + for (unsigned i = 0; i < bc.imm(); ++i) { + out << bc.imm(i); + if (i < bc.imm() - 1) + out << ", "; + } + out << ")"; + } + return out; + } + + private: + unsigned parsed = 0; + std::array args; +}; + +class RBCCode { + private: + SEXP body; + Preserve p; + int* stream() { return INTEGER(body); } + const int* stream() const { return INTEGER(body); } + + public: + explicit RBCCode(SEXP body) : body(R_bcDecode(BCODE_CODE(body))) { + p(body); + } + size_t length() const { return XLENGTH(body); } + int version() const { return stream()[0]; } + + RBC operator[](size_t pos) const { return RBC(stream() + pos); } + + struct Iterator { + const int* pos; + explicit Iterator(const int* pos) : pos(pos) {} + + RBC operator*() { return RBC(pos); } + void operator++() { pos += this->operator*().imm() + 1; } + bool operator!=(const Iterator& other) const { + return pos != other.pos; + } + bool operator==(const Iterator& other) const { + return pos == other.pos; + } + + size_t label(const Iterator& begin) const { + return pos - begin.pos + 1; + } + Iterator operator+(size_t n) const { + auto pos = *this; + for (size_t i = 0; i < n; ++i) + ++pos; + return pos; + } + std::list successors() { + auto bc = this->operator*(); + if (bc.jumps() && bc.falls()) + return {*this + bc.jump(), *this + 1}; + if (bc.jumps()) + return {*this + bc.jump()}; + if (bc.falls()) + return {*this + 1}; + return {}; + } + }; + Iterator begin() const { return Iterator(stream() + 1); } + Iterator end() const { return Iterator(stream() + length()); } + + friend std::ostream& operator<<(std::ostream& out, const RBCCode& code) { + out << "Length : " << code.length() << "\n"; + out << "Version : " << code.version() << "\n"; + for (auto pc = code.begin(); pc != code.end(); ++pc) { + out << std::left << std::setw(6) << pc.label(code.begin()) << *pc + << "\n"; + } + return out; + } +}; + +struct CompilerInfo { + CompilerInfo(SEXP src) : src(src) {} + SEXP src; + std::unordered_set mergepoints; + std::unordered_map jumps; +}; + +static void findMerges(const RBCCode& bytecode, CompilerInfo& info) { + std::unordered_map> incom; + // Mark incoming jmps + for (auto pc = bytecode.begin(); pc != bytecode.end(); ++pc) { + auto l = pc.label(bytecode.begin()); + auto bc = *pc; + if (bc == RBC::BRIFNOT_OP) + incom[bc.imm(1)].insert(l); + else if (bc == RBC::GOTO_OP) + incom[bc.imm(0)].insert(l); + } + + // Add fall-through cases + for (auto pc = bytecode.begin(); pc != bytecode.end(); ++pc) { + if (!(*pc).falls()) + continue; + auto next = pc + 1; + auto next_label = next.label(bytecode.begin()); + auto next_merge = incom.find(next_label); + if (next_merge != incom.end()) { + next_merge->second.insert(pc.label(bytecode.begin())); + } + } + + // Create mergepoints + for (auto m : incom) + // The first position must also be considered a mergepoint in case it + // has only one incoming (a jump) + if (m.first == 1 || m.second.size() > 1) + info.mergepoints.insert(m.first); +} + +pir::ClosureVersion* Gnur2Pir::compile(SEXP src, const std::string& name) { + SEXP body = BODY(src); + RBCCode bytecode(body); + SEXP consts = BCODE_CONSTS(body); + auto cnst = [&](int i) { return VECTOR_ELT(consts, i); }; + + std::cout << bytecode; + + CompilerInfo info(src); + + findMerges(bytecode, info); + + std::cout << "Merges at: "; + for (auto& m : info.mergepoints) { + std::cout << m << " "; + } + std::cout << "\n"; + + auto c = m.getOrDeclareGnurClosure(name, src, Context()); + auto v = c->declareVersion(Compiler::defaultContext, true, nullptr); + + Builder insert(v, c->closureEnv()); + + struct Stack { + private: + std::stack stack; + + public: + void push(Value* v) { stack.push(v); } + Value* pop() { + auto v = stack.top(); + stack.pop(); + return v; + } + ~Stack() { assert(stack.empty()); } + }; + Stack stack; + for (auto pc = bytecode.begin(); pc != bytecode.end(); ++pc) { + auto bc = *pc; + switch (bc.id) { + case RBC::GETVAR_OP: { + auto v = insert(new LdVar(cnst(bc.imm(0)), insert.env)); + stack.push( + insert(new Force(v, insert.env, Tombstone::framestate()))); + break; + } + case RBC::RETURN_OP: + insert(new Return(stack.pop())); + break; + default: + std::cerr << "Could not compile " << *pc << "\n"; + return nullptr; + } + } + return v; +} + +} // namespace pir +} // namespace rir diff --git a/rir/src/compiler/gnur2pir/gnur2pir.h b/rir/src/compiler/gnur2pir/gnur2pir.h new file mode 100644 index 000000000..8ef35e21c --- /dev/null +++ b/rir/src/compiler/gnur2pir/gnur2pir.h @@ -0,0 +1,24 @@ +#ifndef GNUR_2_PIR_H +#define GNUR_2_PIR_H + +#include "compiler/compiler.h" +#include "compiler/pir/builder.h" + +#include +#include + +namespace rir { +namespace pir { + +class Gnur2Pir { + Module& m; + + public: + Gnur2Pir(Module& m) : m(m){}; + pir::ClosureVersion* compile(SEXP src, const std::string& name); +}; + +} // namespace pir +} // namespace rir + +#endif diff --git a/rir/src/compiler/log/stream_logger.cpp b/rir/src/compiler/log/stream_logger.cpp index 0a5ffcce3..0e711866f 100644 --- a/rir/src/compiler/log/stream_logger.cpp +++ b/rir/src/compiler/log/stream_logger.cpp @@ -93,7 +93,8 @@ ClosureStreamLogger& StreamLogger::begin(ClosureVersion* cls) { if (options.includes(DebugFlag::PrintEarlyRir)) { logger.preparePrint(); logger.section("Original version"); - cls->owner()->rirFunction()->disassemble(logger.out().out); + cls->owner()->rirFunction( + [&](rir::Function* f) { f->disassemble(logger.out().out); }); logger.out() << "\n"; } diff --git a/rir/src/compiler/native/pir_jit_llvm.h b/rir/src/compiler/native/pir_jit_llvm.h index 6c6fd68c0..5ee96dd37 100644 --- a/rir/src/compiler/native/pir_jit_llvm.h +++ b/rir/src/compiler/native/pir_jit_llvm.h @@ -75,13 +75,7 @@ class PirJitLLVM { static std::string makeName(Code* c) { std::stringstream ss; ss << "rsh_"; - if (auto cls = ClosureVersion::Cast(c)) { - ss << cls->name(); - } else if (auto p = Promise::Cast(c)) { - ss << p->owner->name() << "_" << *p; - } else { - assert(false); - } + c->printName(ss); ss << "." << nModules; return ss.str(); } diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index d1cdfbbe6..fdbdce02e 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -248,8 +248,9 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, auto availableAssumptions = call->inferAvailableAssumptions(); assert(version->context().numMissing() <= availableAssumptions.numMissing()); - cls->rirFunction()->clearDisabledAssumptions( - availableAssumptions); + cls->rirFunction([&](rir::Function* f) { + f->clearDisabledAssumptions(availableAssumptions); + }); // We picked up more assumptions, let's compile a better // version. Maybe we should limit this at some point, to avoid @@ -366,7 +367,9 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (!newAssumptions.isNotObj(i) && newAssumptions.isEager(i)) newAssumptions.setNotObj(i); - cls->rirFunction()->clearDisabledAssumptions(newAssumptions); + cls->rirFunction([&](rir::Function* f) { + f->clearDisabledAssumptions(newAssumptions); + }); auto newVersion = cls->cloneWithAssumptions( version, newAssumptions, [&](ClosureVersion* newCls) { diff --git a/rir/src/compiler/opt/force_dominance.cpp b/rir/src/compiler/opt/force_dominance.cpp index 737e151ae..892817bdf 100644 --- a/rir/src/compiler/opt/force_dominance.cpp +++ b/rir/src/compiler/opt/force_dominance.cpp @@ -299,7 +299,8 @@ bool ForceDominance::apply(Compiler&, ClosureVersion* cls, Code* code, auto pos = split->begin(); MkArg* fixedMkArg = - new MkArg(mkarg->prom(), promRes, mkarg->promEnv()); + new MkArg(mkarg->prom(), promRes, mkarg->promEnv(), + mkarg->originalIdx); pos = split->insert(pos, fixedMkArg); pos++; CastType* upcast = new CastType( diff --git a/rir/src/compiler/opt/inline.cpp b/rir/src/compiler/opt/inline.cpp index e8818a55f..1df803065 100644 --- a/rir/src/compiler/opt/inline.cpp +++ b/rir/src/compiler/opt/inline.cpp @@ -27,11 +27,15 @@ bool Inline::apply(Compiler&, ClosureVersion* cls, Code* code, return false; auto dontInline = [](Closure* cls) { - if (cls->rirFunction()->flags.contains(rir::Function::DisableInline)) - return true; - if (cls->rirFunction()->flags.contains(rir::Function::ForceInline)) - return false; - return cls->rirFunction()->flags.contains(rir::Function::NotInlineable); + bool res = false; + cls->rirFunction([&](rir::Function* f) { + if (f->flags.contains(rir::Function::DisableInline)) + res = true; + if (f->flags.contains(rir::Function::ForceInline)) + res = false; + res = f->flags.contains(rir::Function::NotInlineable); + }); + return res; }; Visitor::run( @@ -208,27 +212,26 @@ bool Inline::apply(Compiler&, ClosureVersion* cls, Code* code, if (inlinee->owner() == cls->owner()) { continue; } else if (weight > Parameter::INLINER_MAX_INLINEE_SIZE) { - if (!inlineeCls->rirFunction()->flags.contains( - rir::Function::ForceInline) && - inlinee->numNonDeoptInstrs() > - Parameter::INLINER_MAX_INLINEE_SIZE * 4) - inlineeCls->rirFunction()->flags.set( - rir::Function::NotInlineable); + inlineeCls->rirFunction([&](rir::Function* f) { + if (!f->flags.contains(rir::Function::ForceInline) && + inlinee->numNonDeoptInstrs() > + Parameter::INLINER_MAX_INLINEE_SIZE * 4) + f->flags.set(rir::Function::NotInlineable); + }); continue; } else { updateAllowInline(inlinee); inlinee->eachPromise( [&](Promise* p) { updateAllowInline(p); }); if (allowInline == SafeToInline::No) { - inlineeCls->rirFunction()->flags.set( - rir::Function::NotInlineable); + inlineeCls->rirFunction([&](rir::Function* f) { + f->flags.set(rir::Function::NotInlineable); + }); continue; } } - if (!inlineeCls->rirFunction()->flags.contains( - rir::Function::ForceInline)) - fuel--; + fuel--; cls->inlinees++; @@ -354,8 +357,9 @@ bool Inline::apply(Compiler&, ClosureVersion* cls, Code* code, if (failedToInline) { delete copy; bb->overrideNext(split); - inlineeCls->rirFunction()->flags.set( - rir::Function::NotInlineable); + inlineeCls->rirFunction([&](rir::Function* f) { + f->flags.set(rir::Function::NotInlineable); + }); } else { anyChange = true; bb->overrideNext(copy); @@ -380,8 +384,8 @@ bool Inline::apply(Compiler&, ClosureVersion* cls, Code* code, mk->updatePromise( cls->promises().at(newPromId[id])); } else { - Promise* clone = - cls->createProm(mk->prom()->rirSrc()); + Promise* clone = cls->createProm( + mk->prom()->expression()); BB* promCopy = BBTransform::clone( mk->prom()->entry, clone, cls); clone->entry = promCopy; diff --git a/rir/src/compiler/opt/inline_force_prom.cpp b/rir/src/compiler/opt/inline_force_prom.cpp index 18c01eb1f..993ed7976 100644 --- a/rir/src/compiler/opt/inline_force_prom.cpp +++ b/rir/src/compiler/opt/inline_force_prom.cpp @@ -40,32 +40,30 @@ bool InlineForcePromises::apply(Compiler&, ClosureVersion* cls, Code* code, if (clsCallee) { - auto functionVersion = clsCallee->rirFunction(); - if (functionVersion->flags.contains( - rir::Function::Flag::DepromiseArgs)) { - - call->eachCallArg([&](InstrArg& v) { - if (auto mkarg = MkArg::Cast(v.val())) { - - anyChange = true; - - auto cast = new CastType( - mkarg, CastType::Kind::Upcast, RType::prom, - PirType::valOrLazy()); - - auto forced = - new Force(cast, mkarg->env(), - Tombstone::framestate()); - v.val() = forced; - ip = bb->insert(ip, cast) + 1; - ip = bb->insert(ip, forced) + 1; - next = ip + 1; - - } - - - }); - } + clsCallee->rirFunction([&](rir::Function* functionVersion) { + if (functionVersion->flags.contains( + rir::Function::Flag::DepromiseArgs)) { + + call->eachCallArg([&](InstrArg& v) { + if (auto mkarg = MkArg::Cast(v.val())) { + + anyChange = true; + + auto cast = new CastType( + mkarg, CastType::Kind::Upcast, + RType::prom, PirType::valOrLazy()); + + auto forced = + new Force(cast, mkarg->env(), + Tombstone::framestate()); + v.val() = forced; + ip = bb->insert(ip, cast) + 1; + ip = bb->insert(ip, forced) + 1; + next = ip + 1; + } + }); + } + }); } } ip = next; diff --git a/rir/src/compiler/opt/promise_splitter.cpp b/rir/src/compiler/opt/promise_splitter.cpp index e1c621959..a65413e49 100644 --- a/rir/src/compiler/opt/promise_splitter.cpp +++ b/rir/src/compiler/opt/promise_splitter.cpp @@ -96,7 +96,8 @@ bool PromiseSplitter::apply(Compiler&, ClosureVersion* cls, Code* code, }); assert(seen); } - auto nmk = new MkArg(mk->prom(), mk->eagerArg(), mk->env()); + auto nmk = new MkArg(mk->prom(), mk->eagerArg(), mk->env(), + mk->originalIdx); auto nct = new CastType(nmk, ct->kind, ct->arg(0).type(), ct->type); pos = bb->insert(pos, nct); bb->insert(pos, nmk); diff --git a/rir/src/compiler/pir/builder.cpp b/rir/src/compiler/pir/builder.cpp index c31c96675..32f3710e3 100644 --- a/rir/src/compiler/pir/builder.cpp +++ b/rir/src/compiler/pir/builder.cpp @@ -110,8 +110,11 @@ Builder::Builder(ClosureVersion* version, Value* closureEnv) std::vector args(closure->nargs()); size_t nargs = version->effectiveNArgs(); - auto depromiseArgs = version->owner()->rirFunction()->flags.contains( - rir::Function::Flag::DepromiseArgs); + auto depromiseArgs = false; + version->owner()->rirFunction([&](rir::Function* f) { + if (f->flags.contains(rir::Function::Flag::DepromiseArgs)) + depromiseArgs = true; + }); for (long i = nargs - 1; i >= 0; --i) { args[i] = this->operator()(new LdArg(i)); @@ -129,10 +132,13 @@ Builder::Builder(ClosureVersion* version, Value* closureEnv) } auto mkenv = new MkEnv(closureEnv, closure->formals().names(), args.data()); - auto rirCode = version->owner()->rirFunction()->body(); - if (rirCode->flags.contains(rir::Code::NeedsFullEnv)) - mkenv->neverStub = true; - mkenv->typeFeedback.srcCode = rirCode; + + version->owner()->rirFunction([&](rir::Function* f) { + auto rirCode = f->body(); + if (rirCode->flags.contains(rir::Code::NeedsFullEnv)) + mkenv->neverStub = true; + mkenv->typeFeedback.srcCode = rirCode; + }); add(mkenv); this->env = mkenv; } diff --git a/rir/src/compiler/pir/closure.cpp b/rir/src/compiler/pir/closure.cpp index bedac76b6..cdc6cf1c1 100644 --- a/rir/src/compiler/pir/closure.cpp +++ b/rir/src/compiler/pir/closure.cpp @@ -1,6 +1,7 @@ #include "closure.h" #include "closure_version.h" #include "env.h" +#include "interpreter/instance.h" #include "runtime/DispatchTable.h" namespace rir { @@ -8,16 +9,18 @@ namespace pir { Closure::Closure(const std::string& name, rir::Function* function, SEXP formals, SEXP srcRef, Context userContext) - : origin_(nullptr), function(function), env(Env::notClosed()), - srcRef_(srcRef), name_(name), formals_(function, formals), - userContext_(userContext) { + : origin_(nullptr), + expression_(src_pool_at(globalContext(), function->body()->src)), + function(function), env(Env::notClosed()), srcRef_(srcRef), name_(name), + formals_(function, formals), userContext_(userContext) { invariant(); } Closure::Closure(const std::string& name, SEXP closure, rir::Function* f, Env* env, Context userContext) - : origin_(closure), function(f), env(env), name_(name), - formals_(f, FORMALS(closure)), userContext_(userContext) { + : origin_(closure), expression_(R_ClosureExpr(closure)), function(f), + env(env), name_(name), formals_(f, FORMALS(closure)), + userContext_(userContext) { static SEXP srcRefSymbol = Rf_install("srcref"); srcRef_ = Rf_getAttrib(closure, srcRefSymbol); @@ -78,5 +81,14 @@ void Closure::print(std::ostream& out, bool tty) const { }); } +size_t Closure::bodySize() const { + if (function) + return function->body()->codeSize; + if (origin_ && TYPEOF(BODY(origin_)) == BCODESXP) + return XLENGTH(BCODE_CODE(BODY(origin_))); + assert(false); + return 0; +} + } // namespace pir } // namespace rir diff --git a/rir/src/compiler/pir/closure.h b/rir/src/compiler/pir/closure.h index 9d7a1307b..b8ba4b938 100644 --- a/rir/src/compiler/pir/closure.h +++ b/rir/src/compiler/pir/closure.h @@ -37,6 +37,7 @@ class Closure { void invariant() const; SEXP origin_; + SEXP expression_; rir::Function* function; Env* env; SEXP srcRef_; @@ -51,17 +52,23 @@ class Closure { return c.smaller(this->userContext_); } + SEXP expression() const { return expression_; } + SEXP rirClosure() const { assert(origin_ && "Inner function does not have a source rir closure"); return origin_; } bool hasOriginClosure() const { return origin_; } - rir::Function* rirFunction() const { return function; } + void rirFunction(const std::function& apply) const { + if (function) + apply(function); + } SEXP srcRef() { return srcRef_; } Env* closureEnv() const { return env; } const std::string& name() const { return name_; } size_t nargs() const { return formals_.nargs(); } + size_t bodySize() const; const FormalArgs& formals() const { return formals_; } void print(std::ostream& out, bool tty) const; diff --git a/rir/src/compiler/pir/closure_version.cpp b/rir/src/compiler/pir/closure_version.cpp index cfc1aadf5..40f1ee4a1 100644 --- a/rir/src/compiler/pir/closure_version.cpp +++ b/rir/src/compiler/pir/closure_version.cpp @@ -74,8 +74,8 @@ void ClosureVersion::printBBGraph(std::ostream& out, out << "}\n"; } -Promise* ClosureVersion::createProm(rir::Code* rirSrc) { - Promise* p = new Promise(this, promises_.size(), rirSrc); +Promise* ClosureVersion::createProm(SEXP expression) { + Promise* p = new Promise(this, promises_.size(), expression); promises_.push_back(p); return p; } @@ -117,6 +117,8 @@ size_t ClosureVersion::effectiveNArgs() const { return owner_->nargs() - optimizationContext_.numMissing(); } +SEXP ClosureVersion::expression() const { return owner()->expression(); } + ClosureVersion::ClosureVersion(Closure* closure, rir::Function* optFunction, bool root, const Context& optimizationContext, const Properties& properties) @@ -130,6 +132,8 @@ ClosureVersion::ClosureVersion(Closure* closure, rir::Function* optFunction, nameSuffix_ = id.str(); } +void ClosureVersion::printName(std::ostream& out) const { out << name(); } + std::ostream& operator<<(std::ostream& out, const ClosureVersion::Property& p) { switch (p) { case ClosureVersion::Property::IsEager: @@ -163,9 +167,5 @@ std::ostream& operator<<(std::ostream& out, return out; } -rir::Code* ClosureVersion::rirSrc() const { - return owner()->rirFunction()->body(); -} - } // namespace pir } // namespace rir diff --git a/rir/src/compiler/pir/closure_version.h b/rir/src/compiler/pir/closure_version.h index f21719168..9a2ae7eb0 100644 --- a/rir/src/compiler/pir/closure_version.h +++ b/rir/src/compiler/pir/closure_version.h @@ -16,8 +16,7 @@ namespace pir { * ClosureVersion * */ -class ClosureVersion - : public CodeImpl { +class ClosureVersion : public Code { public: enum class Property { IsEager, @@ -42,9 +41,8 @@ class ClosureVersion const bool root; - rir::Function* optFunction; - private: + rir::Function* optFunction; Closure* owner_; std::vector promises_; const Context& optimizationContext_; @@ -70,6 +68,7 @@ class ClosureVersion const std::string& name() const { return name_; } const std::string& nameSuffix() const { return nameSuffix_; } + void printName(std::ostream& out) const override final; void print(std::ostream& out, bool tty) const; void print(DebugStyle style, std::ostream& out, bool tty, bool omitDeoptBranches) const; @@ -78,13 +77,21 @@ class ClosureVersion void printGraph(std::ostream& out, bool omitDeoptBranches) const; void printBBGraph(std::ostream& out, bool omitDeoptBranches) const; - Promise* createProm(rir::Code* rirSrc); + Promise* createProm(SEXP expr); Promise* promise(unsigned id) const { return promises_.at(id); } const std::vector& promises() { return promises_; } void erasePromise(unsigned id); + SEXP expression() const override final; + + PirTypeFeedback* pirTypeFeedback() { + if (optFunction) + optFunction->body()->pirTypeFeedback(); + return nullptr; + } + typedef std::function PromiseIterator; void eachPromise(PromiseIterator it) const { for (auto p : promises_) @@ -94,8 +101,6 @@ class ClosureVersion size_t numNonDeoptInstrs() const; - rir::Code* rirSrc() const override final; - friend std::ostream& operator<<(std::ostream& out, const ClosureVersion& e) { out << e.name(); diff --git a/rir/src/compiler/pir/code.h b/rir/src/compiler/pir/code.h index 92e100404..ff1c4e4a0 100644 --- a/rir/src/compiler/pir/code.h +++ b/rir/src/compiler/pir/code.h @@ -1,6 +1,7 @@ #ifndef COMPILER_CODE_H #define COMPILER_CODE_H +#include "R/r_incl.h" #include "pir.h" #include @@ -11,13 +12,6 @@ struct Code; namespace pir { -enum class CodeTag : uint8_t { - ClosureVersion, - Promise, - - Invalid -}; - /* * A piece of code, starting at the BB entry. * @@ -26,36 +20,19 @@ enum class CodeTag : uint8_t { */ class Code { public: - CodeTag tag; BB* entry = nullptr; size_t nextBBId = 0; - explicit Code(CodeTag tag = CodeTag::Invalid) : tag(tag) {} + Code() {} void printCode(std::ostream&, bool tty, bool omitDeoptBranches) const; void printGraphCode(std::ostream&, bool omitDeoptBranches) const; void printBBGraphCode(std::ostream&, bool omitDeoptBranches) const; virtual ~Code(); + virtual void printName(std::ostream& out) const = 0; + virtual SEXP expression() const = 0; size_t numInstrs() const; - - virtual rir::Code* rirSrc() const = 0; -}; - -template -class CodeImpl : public Code { - public: - CodeImpl() : Code(CTAG) {} - static const Base* Cast(const Code* c) { - if (c->tag == CTAG) - return static_cast(c); - return nullptr; - } - static Base* Cast(Code* c) { - if (c->tag == CTAG) - return static_cast(c); - return nullptr; - } }; } // namespace pir diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index faeb51fe3..6899c050c 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -571,10 +571,10 @@ void Branch::printGraphBranches(std::ostream& out, size_t bbId) const { << " [color=red]; // -> BB" << falseBB->id << "\n"; } -MkArg::MkArg(Promise* prom, Value* v, Value* env) +MkArg::MkArg(Promise* prom, Value* v, Value* env, size_t originalIdx) : FixedLenInstructionWithEnvSlot(RType::prom, {{PirType::val()}}, {{v}}, env), - prom_(prom) { + prom_(prom), originalIdx(originalIdx) { assert(eagerArg() == v); assert(!MkArg::Cast(eagerArg()->followCasts())); if (isEager()) { @@ -584,7 +584,7 @@ MkArg::MkArg(Promise* prom, Value* v, Value* env) void MkArg::printArgs(std::ostream& out, bool tty) const { eagerArg()->printRef(out); - out << ", " << *prom(); + out << ", " << prom()->id; if (noReflection) out << " (!refl)"; out << ", "; diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index 64a315e7e..ad628bbac 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -1137,8 +1137,9 @@ class FLIE(MkArg, 2, Effects::None()) { public: bool noReflection = false; + const size_t originalIdx; - MkArg(Promise* prom, Value* v, Value* env); + MkArg(Promise* prom, Value* v, Value* env, size_t originalIdx); Value* eagerArg() const { return arg(0).val(); } void eagerArg(Value* eager) { diff --git a/rir/src/compiler/pir/module.cpp b/rir/src/compiler/pir/module.cpp index d246fde83..dfdf3ac29 100644 --- a/rir/src/compiler/pir/module.cpp +++ b/rir/src/compiler/pir/module.cpp @@ -11,6 +11,18 @@ void Module::print(std::ostream& out, bool tty) { }); } +Closure* Module::getOrDeclareGnurClosure(const std::string& name, SEXP closure, + Context userContext) { + // For Identification we use the real env, but for optimization we only use + // the real environment if this is not an inner function. When it is an + // inner function, then the env is expected to change over time. + auto env = getEnv(CLOENV(closure)); + auto id = Idx(BODY(closure), env); + if (!closures.count(id)) + closures[id] = new Closure(name, closure, nullptr, env, userContext); + return closures.at(id); +} + Closure* Module::getOrDeclareRirFunction(const std::string& name, rir::Function* f, SEXP formals, SEXP src, Context userContext) { diff --git a/rir/src/compiler/pir/module.h b/rir/src/compiler/pir/module.h index 8981df7e1..0c5479be5 100644 --- a/rir/src/compiler/pir/module.h +++ b/rir/src/compiler/pir/module.h @@ -27,6 +27,8 @@ class Module { Context userContext); Closure* getOrDeclareRirClosure(const std::string& name, SEXP closure, rir::Function* f, Context userContext); + Closure* getOrDeclareGnurClosure(const std::string& name, SEXP closure, + Context userContext); typedef std::function PirClosureIterator; typedef std::function PirClosureVersionIterator; @@ -35,7 +37,7 @@ class Module { ~Module(); private: - typedef std::pair Idx; + typedef std::pair Idx; std::map closures; }; diff --git a/rir/src/compiler/pir/promise.cpp b/rir/src/compiler/pir/promise.cpp index 1efc06abd..931cfc2e3 100644 --- a/rir/src/compiler/pir/promise.cpp +++ b/rir/src/compiler/pir/promise.cpp @@ -1,5 +1,6 @@ #include "promise.h" #include "compiler/pir/bb.h" +#include "compiler/pir/closure_version.h" #include "compiler/pir/instruction.h" #include "compiler/util/visitor.h" #include "interpreter/instance.h" @@ -8,12 +9,12 @@ namespace rir { namespace pir { -Promise::Promise(ClosureVersion* owner, unsigned id, rir::Code* rirSrc) - : id(id), owner(owner), rirSrc_(rirSrc), srcPoolIdx_(rirSrc->src) { - assert(src_pool_at(globalContext(), srcPoolIdx_)); -} +Promise::Promise(ClosureVersion* owner, unsigned id, SEXP expression) + : id(id), owner(owner), expression_(expression) {} -unsigned Promise::srcPoolIdx() const { return srcPoolIdx_; } +unsigned Promise::srcPoolIdx() const { + return src_pool_add(globalContext(), expression()); +} LdFunctionEnv* Promise::env() const { LdFunctionEnv* e = nullptr; @@ -43,5 +44,9 @@ bool Promise::trivial() const { return true; } +void Promise::printName(std::ostream& out) const { + out << owner->name() << "_" << id; +} + } // namespace pir } // namespace rir diff --git a/rir/src/compiler/pir/promise.h b/rir/src/compiler/pir/promise.h index 1a970c663..717076b6e 100644 --- a/rir/src/compiler/pir/promise.h +++ b/rir/src/compiler/pir/promise.h @@ -8,28 +8,25 @@ namespace pir { class LdFunctionEnv; -class Promise : public CodeImpl { +class Promise : public Code { public: const unsigned id; ClosureVersion* owner; - friend std::ostream& operator<<(std::ostream& out, const Promise& p) { - out << "Prom(" << p.id << ")"; - return out; - } - unsigned srcPoolIdx() const; - rir::Code* rirSrc() const override final { return rirSrc_; } LdFunctionEnv* env() const; bool trivial() const; + SEXP expression() const override final { return expression_; } + + void printName(std::ostream& out) const override final; + private: - rir::Code* rirSrc_; - const unsigned srcPoolIdx_; + SEXP expression_; friend class ClosureVersion; - Promise(ClosureVersion* owner, unsigned id, rir::Code* rirSrc); + Promise(ClosureVersion* owner, unsigned id, SEXP expression); }; } // namespace pir diff --git a/rir/src/compiler/pir/type.h b/rir/src/compiler/pir/type.h index 5f3a283b2..aae090449 100644 --- a/rir/src/compiler/pir/type.h +++ b/rir/src/compiler/pir/type.h @@ -796,13 +796,13 @@ inline std::ostream& operator<<(std::ostream& out, PirType t) { else if (t.maybePromiseWrapped()) out << "~"; if (!t.maybeHasAttrs()) { - out << "⁻"; + out << ""; } else { if (!t.maybeNotFastVecelt()) { assert(!t.maybeObj()); - out << "ⁿ"; + out << "_"; } else if (!t.maybeObj()) { - out << "⁺"; + out << "+"; } } diff --git a/rir/src/compiler/rir2pir/rir2pir.cpp b/rir/src/compiler/rir2pir/rir2pir.cpp index c0ce0784c..8a950c8a1 100644 --- a/rir/src/compiler/rir2pir/rir2pir.cpp +++ b/rir/src/compiler/rir2pir/rir2pir.cpp @@ -144,9 +144,9 @@ Rir2Pir::Rir2Pir(Compiler& cmp, ClosureVersion* cls, ClosureStreamLogger& log, const std::list& outerFeedback) : compiler(cmp), cls(cls), log(log), name(name), outerFeedback(outerFeedback) { - if (cls->optFunction && cls->optFunction->body()->pirTypeFeedback()) - this->outerFeedback.push_back( - cls->optFunction->body()->pirTypeFeedback()); + if (cls->pirTypeFeedback()) { + this->outerFeedback.push_back(cls->pirTypeFeedback()); + } } Checkpoint* Rir2Pir::addCheckpoint(rir::Code* srcCode, Opcode* pos, @@ -158,8 +158,9 @@ Checkpoint* Rir2Pir::addCheckpoint(rir::Code* srcCode, Opcode* pos, } Value* Rir2Pir::tryCreateArg(rir::Code* promiseCode, Builder& insert, - bool eager) { - Promise* prom = insert.function->createProm(promiseCode); + bool eager, size_t originalIdx) { + Promise* prom = insert.function->createProm( + src_pool_at(globalContext(), promiseCode->src)); { Builder promiseBuilder(insert.function, prom); if (!tryCompilePromise(promiseCode, promiseBuilder)) { @@ -182,7 +183,7 @@ Value* Rir2Pir::tryCreateArg(rir::Code* promiseCode, Builder& insert, return eagerVal; } - return insert(new MkArg(prom, eagerVal, insert.env)); + return insert(new MkArg(prom, eagerVal, insert.env, originalIdx)); } struct TargetInfo { @@ -566,7 +567,8 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, Value* val = UnboundValue::instance(); if (bc.bc == Opcode::mk_eager_promise_) val = pop(); - Promise* prom = insert.function->createProm(promiseCode); + Promise* prom = insert.function->createProm( + src_pool_at(globalContext(), promiseCode->src)); { Builder promiseBuilder(insert.function, prom); if (!tryCompilePromise(promiseCode, promiseBuilder)) { @@ -581,7 +583,7 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, return false; } } - push(insert(new MkArg(prom, val, env))); + push(insert(new MkArg(prom, val, env, promi))); break; } @@ -689,8 +691,10 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, args[i] = mk->eagerArg(); } else { assert(at(nargs - 1 - i) == args[i]); - args[i] = - tryCreateArg(mk->prom()->rirSrc(), insert, true); + rir::Code* promiseCode = + srcCode->getPromise(mk->originalIdx); + args[i] = tryCreateArg(promiseCode, insert, true, + mk->originalIdx); if (!args[i]) { log.warn("Failed to compile a promise"); return false; @@ -698,8 +702,8 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, // Inlined argument evaluation might have side effects. // Let's have a checkpoint here. This checkpoint needs // to capture the so far evaluated promises. - stack.at(nargs - 1 - i) = - insert(new MkArg(mk->prom(), args[i], mk->env())); + stack.at(nargs - 1 - i) = insert(new MkArg( + mk->prom(), args[i], mk->env(), mk->originalIdx)); addCheckpoint(srcCode, pos, stack, insert); } } @@ -760,7 +764,8 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, Rf_findFun(Rf_install("match.arg"), R_BaseNamespace); if (ti.monomorphic == argmatchFun && matchedArgs.size() == 1) { if (auto mk = MkArg::Cast(matchedArgs[0])) { - auto varName = mk->prom()->rirSrc()->trivialExpr; + auto varName = + srcCode->getPromise(mk->originalIdx)->trivialExpr; if (TYPEOF(varName) == SYMSXP) { auto& formals = cls->owner()->formals(); auto f = std::find(formals.names().begin(), @@ -1197,10 +1202,6 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, return true; } // namespace pir -bool Rir2Pir::tryCompile(Builder& insert) { - return tryCompile(cls->owner()->rirFunction()->body(), insert); -} - bool Rir2Pir::tryCompile(rir::Code* srcCode, Builder& insert) { if (auto mk = MkEnv::Cast(insert.env)) { mk->eachLocalVar([&](SEXP name, Value*, bool) { @@ -1454,15 +1455,13 @@ Value* Rir2Pir::tryTranslate(rir::Code* srcCode, Builder& insert) { } } inner << "@"; - if (srcCode != cls->owner()->rirFunction()->body()) { - size_t i = 0; - for (auto c : insert.function->promises()) { - if (c == insert.code) { - inner << "Prom(" << i << ")"; - break; - } - i++; + size_t i = 0; + for (auto c : insert.function->promises()) { + if (c == insert.code) { + inner << "_" << i; + break; } + i++; } inner << (pos - srcCode->code()); diff --git a/rir/src/compiler/rir2pir/rir2pir.h b/rir/src/compiler/rir2pir/rir2pir.h index 2ac4da663..848b43c80 100644 --- a/rir/src/compiler/rir2pir/rir2pir.h +++ b/rir/src/compiler/rir2pir/rir2pir.h @@ -19,11 +19,14 @@ class Rir2Pir { const std::string& name, const std::list& outerFeedback); - bool tryCompile(Builder& insert) __attribute__((warn_unused_result)); - - Value* tryCreateArg(rir::Code* prom, Builder& insert, bool eager) + // Tries to compile the srcCode. Return value indicates failure. Builder + // has to be discarded, if compilation fails! + bool tryCompile(rir::Code* srcCode, Builder& insert) __attribute__((warn_unused_result)); + Value* tryCreateArg(rir::Code* prom, Builder& insert, bool eager, + size_t originalIdx) __attribute__((warn_unused_result)); + typedef std::unordered_map< Value*, std::tuple> CallTargetFeedback; @@ -32,10 +35,6 @@ class Rir2Pir { Value* tryInlinePromise(rir::Code* srcCode, Builder& insert) __attribute__((warn_unused_result)); - // Tries to compile the srcCode. Return value indicates failure. Builder - // has to be discarded, if compilation fails! - bool tryCompile(rir::Code* srcCode, Builder& insert) - __attribute__((warn_unused_result)); bool tryCompilePromise(rir::Code* prom, Builder& insert) __attribute__((warn_unused_result)); diff --git a/rir/src/compiler/test/PirTests.cpp b/rir/src/compiler/test/PirTests.cpp index 570a6838e..110393b72 100644 --- a/rir/src/compiler/test/PirTests.cpp +++ b/rir/src/compiler/test/PirTests.cpp @@ -374,7 +374,8 @@ class MockBB : public BB { // ~Code wants to delete something entry = new MockBB; } - rir::Code* rirSrc() const override final { return nullptr; } + SEXP expression() const override final { return nullptr; } + void printName(std::ostream&) const override final {} }; public: diff --git a/rir/src/compiler/util/bb_transform.cpp b/rir/src/compiler/util/bb_transform.cpp index 2eed0e21b..dbbebb4da 100644 --- a/rir/src/compiler/util/bb_transform.cpp +++ b/rir/src/compiler/util/bb_transform.cpp @@ -56,7 +56,7 @@ BB* BBTransform::clone(BB* src, Code* target, ClosureVersion* targetClosure) { if (promMap.count(p)) { mk->updatePromise(promMap.at(p)); } else { - auto c = targetClosure->createProm(p->rirSrc()); + auto c = targetClosure->createProm(p->expression()); c->entry = clone(p->entry, c, targetClosure); mk->updatePromise(c); } diff --git a/rir/src/runtime/Code.cpp b/rir/src/runtime/Code.cpp index 2b1e771d7..d37883901 100644 --- a/rir/src/runtime/Code.cpp +++ b/rir/src/runtime/Code.cpp @@ -42,7 +42,7 @@ Code* Code::New(Immediate ast, size_t codeSize, size_t sources, size_t locals, codeSize, sources, locals, bindingCache); } -Code* Code::New(Immediate ast) { return New(ast, 0, 0, 0, 0); } +Code* Code::New(SEXP ast) { return New(ast, 0, 0, 0, 0); } Code::~Code() { // TODO: Not sure if this is actually called diff --git a/rir/src/runtime/Code.h b/rir/src/runtime/Code.h index f6bd3adb7..d289f8e08 100644 --- a/rir/src/runtime/Code.h +++ b/rir/src/runtime/Code.h @@ -70,7 +70,7 @@ struct Code : public RirRuntimeObject { size_t bindingCache); static Code* New(Immediate ast, size_t codeSize, size_t sources, size_t locals, size_t bindingCache); - static Code* New(Immediate ast); + static Code* New(SEXP ast); NativeCode nativeCode; diff --git a/rir/src/utils/FormalArgs.h b/rir/src/utils/FormalArgs.h index c920035d2..c6f54a30a 100644 --- a/rir/src/utils/FormalArgs.h +++ b/rir/src/utils/FormalArgs.h @@ -26,6 +26,11 @@ class FormalArgs { if (it.tag() == R_DotsSymbol) hasDots_ = true; + if (!function) { + defaultArgs_.push_back(*it); + continue; + } + auto arg = function->defaultArg(i); if (*it != R_MissingArg) { assert(arg != nullptr && "Rir compiled function is missing a " From 886e11b399dceb4e1d728eba41b8b844d2dbf6f4 Mon Sep 17 00:00:00 2001 From: oli Date: Fri, 9 Apr 2021 09:39:23 +0000 Subject: [PATCH 2/4] refactoring a bit --- rir/src/compiler/gnur2pir/gnur2pir.cpp | 124 ++++++++++++++++++------- rir/src/compiler/opt/inline.cpp | 5 +- rir/tests/gnur2pir.R | 10 ++ 3 files changed, 106 insertions(+), 33 deletions(-) create mode 100644 rir/tests/gnur2pir.R diff --git a/rir/src/compiler/gnur2pir/gnur2pir.cpp b/rir/src/compiler/gnur2pir/gnur2pir.cpp index 1a5029543..4a791dc4d 100644 --- a/rir/src/compiler/gnur2pir/gnur2pir.cpp +++ b/rir/src/compiler/gnur2pir/gnur2pir.cpp @@ -368,9 +368,26 @@ class RBCCode { } }; +struct Stack { + private: + std::stack stack; + + public: + void push(Value* v) { stack.push(v); } + Value* pop() { + auto v = stack.top(); + stack.pop(); + return v; + } + ~Stack() { assert(stack.empty()); } +}; + struct CompilerInfo { - CompilerInfo(SEXP src) : src(src) {} + CompilerInfo(SEXP src, Stack& stack, Builder& insert) + : src(src), stack(stack), insert(insert) {} SEXP src; + Stack& stack; + Builder& insert; std::unordered_set mergepoints; std::unordered_map jumps; }; @@ -407,57 +424,102 @@ static void findMerges(const RBCCode& bytecode, CompilerInfo& info) { info.mergepoints.insert(m.first); } +struct BCCompiler { + CompilerInfo& cmp; + BCCompiler(CompilerInfo& cmp) : cmp(cmp) {} + + void push(Value* v) { cmp.stack.push(v); } + + Value* pop() { return cmp.stack.pop(); } + + Instruction* insertPush(Instruction* i) { + push(insert(i)); + return i; + } + + Instruction* insert(Instruction* i) { + cmp.insert(i); + return i; + } + + SEXP cnst(int i) { return VECTOR_ELT(BCODE_CONSTS(BODY(cmp.src)), i); } + + Value* env() { return cmp.insert.env; } + + template + void compile(RBC); +}; + +// Start instructions translation + +template <> +void BCCompiler::compile(RBC bc) { + auto v = insert(new LdVar(cnst(bc.imm(0)), env())); + insertPush(new Force(v, env(), Tombstone::framestate())); +} + +template <> +void BCCompiler::compile(RBC bc) { + insert(new Return(pop())); +} + +template <> +void BCCompiler::compile(RBC bc) { + push(insert(new LdConst(cnst(bc.imm(0))))); +} + +template <> +void BCCompiler::compile(RBC bc) { + auto a = pop(); + auto b = pop(); + insertPush(new Add(a, b, env(), bc.imm(0))); +} + +// End instructions translation + pir::ClosureVersion* Gnur2Pir::compile(SEXP src, const std::string& name) { SEXP body = BODY(src); RBCCode bytecode(body); - SEXP consts = BCODE_CONSTS(body); - auto cnst = [&](int i) { return VECTOR_ELT(consts, i); }; std::cout << bytecode; - CompilerInfo info(src); + auto c = m.getOrDeclareGnurClosure(name, src, Context()); + auto v = c->declareVersion(Compiler::defaultContext, true, nullptr); + + Stack stack; + Builder insert(v, c->closureEnv()); + + CompilerInfo info(src, stack, insert); findMerges(bytecode, info); - std::cout << "Merges at: "; + std::cout << "CFG merges at: "; for (auto& m : info.mergepoints) { std::cout << m << " "; } std::cout << "\n"; - auto c = m.getOrDeclareGnurClosure(name, src, Context()); - auto v = c->declareVersion(Compiler::defaultContext, true, nullptr); - - Builder insert(v, c->closureEnv()); + BCCompiler bccompiler(info); - struct Stack { - private: - std::stack stack; +#define SUPPORTED_INSTRUCTIONS(V) \ + V(RBC::GETVAR_OP) \ + V(RBC::RETURN_OP) \ + V(RBC::LDCONST_OP) \ + V(RBC::ADD_OP) - public: - void push(Value* v) { stack.push(v); } - Value* pop() { - auto v = stack.top(); - stack.pop(); - return v; - } - ~Stack() { assert(stack.empty()); } - }; - Stack stack; for (auto pc = bytecode.begin(); pc != bytecode.end(); ++pc) { auto bc = *pc; switch (bc.id) { - case RBC::GETVAR_OP: { - auto v = insert(new LdVar(cnst(bc.imm(0)), insert.env)); - stack.push( - insert(new Force(v, insert.env, Tombstone::framestate()))); - break; - } - case RBC::RETURN_OP: - insert(new Return(stack.pop())); - break; +#define V(BC) \ + case BC: \ + bccompiler.compile(bc); \ + break; + SUPPORTED_INSTRUCTIONS(V) +#undef V + default: std::cerr << "Could not compile " << *pc << "\n"; + assert(false); return nullptr; } } diff --git a/rir/src/compiler/opt/inline.cpp b/rir/src/compiler/opt/inline.cpp index 1df803065..88ac8b86e 100644 --- a/rir/src/compiler/opt/inline.cpp +++ b/rir/src/compiler/opt/inline.cpp @@ -31,9 +31,10 @@ bool Inline::apply(Compiler&, ClosureVersion* cls, Code* code, cls->rirFunction([&](rir::Function* f) { if (f->flags.contains(rir::Function::DisableInline)) res = true; - if (f->flags.contains(rir::Function::ForceInline)) + else if (f->flags.contains(rir::Function::ForceInline)) res = false; - res = f->flags.contains(rir::Function::NotInlineable); + else + res = f->flags.contains(rir::Function::NotInlineable); }); return res; }; diff --git a/rir/tests/gnur2pir.R b/rir/tests/gnur2pir.R new file mode 100644 index 000000000..ea70d96d8 --- /dev/null +++ b/rir/tests/gnur2pir.R @@ -0,0 +1,10 @@ +test <- function(f, arg, expected) { + f = compiler::cmpfun(f) + compiler::disassemble(f) + .Call('gnur2pir', f) + stopifnot(identical(f(arg), expected)) +} + +test(function(a) 1,, 1) +test(function(a) a, 1, 1) +test(function(a) a+1, 1, 2) From 62720dfdc5a09cf9280dbfeac5c74b320d0df204 Mon Sep 17 00:00:00 2001 From: oli Date: Fri, 23 Apr 2021 12:52:49 +0000 Subject: [PATCH 3/4] cppcheck --- rir/src/compiler/gnur2pir/gnur2pir.cpp | 12 ++++++------ rir/src/compiler/gnur2pir/gnur2pir.h | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rir/src/compiler/gnur2pir/gnur2pir.cpp b/rir/src/compiler/gnur2pir/gnur2pir.cpp index 4a791dc4d..2f41da46c 100644 --- a/rir/src/compiler/gnur2pir/gnur2pir.cpp +++ b/rir/src/compiler/gnur2pir/gnur2pir.cpp @@ -426,7 +426,7 @@ static void findMerges(const RBCCode& bytecode, CompilerInfo& info) { struct BCCompiler { CompilerInfo& cmp; - BCCompiler(CompilerInfo& cmp) : cmp(cmp) {} + explicit BCCompiler(CompilerInfo& cmp) : cmp(cmp) {} void push(Value* v) { cmp.stack.push(v); } @@ -447,29 +447,29 @@ struct BCCompiler { Value* env() { return cmp.insert.env; } template - void compile(RBC); + void compile(const RBC&); }; // Start instructions translation template <> -void BCCompiler::compile(RBC bc) { +void BCCompiler::compile(const RBC& bc) { auto v = insert(new LdVar(cnst(bc.imm(0)), env())); insertPush(new Force(v, env(), Tombstone::framestate())); } template <> -void BCCompiler::compile(RBC bc) { +void BCCompiler::compile(const RBC& bc) { insert(new Return(pop())); } template <> -void BCCompiler::compile(RBC bc) { +void BCCompiler::compile(const RBC& bc) { push(insert(new LdConst(cnst(bc.imm(0))))); } template <> -void BCCompiler::compile(RBC bc) { +void BCCompiler::compile(const RBC& bc) { auto a = pop(); auto b = pop(); insertPush(new Add(a, b, env(), bc.imm(0))); diff --git a/rir/src/compiler/gnur2pir/gnur2pir.h b/rir/src/compiler/gnur2pir/gnur2pir.h index 8ef35e21c..3509a496e 100644 --- a/rir/src/compiler/gnur2pir/gnur2pir.h +++ b/rir/src/compiler/gnur2pir/gnur2pir.h @@ -14,7 +14,7 @@ class Gnur2Pir { Module& m; public: - Gnur2Pir(Module& m) : m(m){}; + explicit Gnur2Pir(Module& m) : m(m){}; pir::ClosureVersion* compile(SEXP src, const std::string& name); }; From d7d9f8310aae7fd9f96a4cfe93d668f01bcddadb Mon Sep 17 00:00:00 2001 From: oli Date: Thu, 29 Apr 2021 15:07:15 +0000 Subject: [PATCH 4/4] add some more bcs --- rir/src/compiler/gnur2pir/gnur2pir.cpp | 265 +++++++++++++++++++++++-- rir/tests/gnur2pir.R | 4 + 2 files changed, 252 insertions(+), 17 deletions(-) diff --git a/rir/src/compiler/gnur2pir/gnur2pir.cpp b/rir/src/compiler/gnur2pir/gnur2pir.cpp index 2f41da46c..b9f642028 100644 --- a/rir/src/compiler/gnur2pir/gnur2pir.cpp +++ b/rir/src/compiler/gnur2pir/gnur2pir.cpp @@ -18,6 +18,7 @@ #include "simple_instruction_list.h" #include "utils/FormalArgs.h" +#include #include #include #include @@ -333,6 +334,10 @@ class RBCCode { bool operator==(const Iterator& other) const { return pos == other.pos; } + size_t operator-(const Iterator& other) const { + assert(pos >= other.pos); + return pos - other.pos; + } size_t label(const Iterator& begin) const { return pos - begin.pos + 1; @@ -370,26 +375,67 @@ class RBCCode { struct Stack { private: - std::stack stack; + std::vector stack; public: - void push(Value* v) { stack.push(v); } + void push(Value* v) { stack.push_back(v); } Value* pop() { - auto v = stack.top(); - stack.pop(); + auto v = stack.back(); + stack.pop_back(); return v; } - ~Stack() { assert(stack.empty()); } + bool isEmpty() const { return stack.empty(); } + friend struct CompilerInfo; }; struct CompilerInfo { CompilerInfo(SEXP src, Stack& stack, Builder& insert) - : src(src), stack(stack), insert(insert) {} + : src(src), stack(stack), insert(insert) { + basicBlocks[1] = insert.getCurrentBB(); + } + BB* getBB(size_t pos) { + auto f = basicBlocks.find(pos); + if (f != basicBlocks.end()) + return f->second; + return basicBlocks + .emplace(pos, new BB(insert.code, insert.code->nextBBId++)) + .first->second; + } SEXP src; Stack& stack; Builder& insert; std::unordered_set mergepoints; + std::unordered_map bbState; + std::unordered_map basicBlocks; std::unordered_map jumps; + + void merge(std::initializer_list trgs) { + for (auto trg : trgs) { + auto bb = getBB(trg); + if (!mergepoints.count(trg)) { + assert(!bbState.count(bb)); + bbState[bb] = stack; + continue; + } + + auto state = bbState.find(bb); + if (state == bbState.end()) { + Stack trgStack; + for (auto v : stack.stack) { + trgStack.push(*bb->insert( + bb->end(), new Phi({{insert.getCurrentBB(), v}}))); + } + bbState.emplace(bb, trgStack); + } else { + Stack& trgStack = state->second; + for (size_t pos = 0; pos < trgStack.stack.size(); ++pos) { + Phi::Cast(trgStack.stack[pos]) + ->addInput(insert.getCurrentBB(), stack.stack[pos]); + } + } + } + stack.stack.clear(); + } }; static void findMerges(const RBCCode& bytecode, CompilerInfo& info) { @@ -430,6 +476,7 @@ struct BCCompiler { void push(Value* v) { cmp.stack.push(v); } + bool stackEmpty() { return cmp.stack.isEmpty(); } Value* pop() { return cmp.stack.pop(); } Instruction* insertPush(Instruction* i) { @@ -447,34 +494,74 @@ struct BCCompiler { Value* env() { return cmp.insert.env; } template - void compile(const RBC&); + void compile(const RBC&, size_t pos); }; // Start instructions translation template <> -void BCCompiler::compile(const RBC& bc) { +void BCCompiler::compile(const RBC& bc, size_t pos) { auto v = insert(new LdVar(cnst(bc.imm(0)), env())); insertPush(new Force(v, env(), Tombstone::framestate())); } template <> -void BCCompiler::compile(const RBC& bc) { +void BCCompiler::compile(const RBC& bc, size_t pos) { insert(new Return(pop())); } template <> -void BCCompiler::compile(const RBC& bc) { +void BCCompiler::compile(const RBC& bc, size_t pos) { push(insert(new LdConst(cnst(bc.imm(0))))); } template <> -void BCCompiler::compile(const RBC& bc) { +void BCCompiler::compile(const RBC& bc, size_t pos) { auto a = pop(); auto b = pop(); insertPush(new Add(a, b, env(), bc.imm(0))); } +template <> +void BCCompiler::compile(const RBC& bc, size_t pos) { + auto t = insert(new CheckTrueFalse(pop())); + auto label = bc.imm(1); + insert(new Branch(t)); + auto fallPos = pos + bc.imm() + 1; + auto fall = cmp.getBB(fallPos); + auto trg = cmp.getBB(label); + cmp.insert.setBranch(fall, trg); + cmp.merge({fallPos, (size_t)label}); +} + +template <> +void BCCompiler::compile(const RBC& bc, size_t pos) { + auto label = bc.imm(0); + auto trg = cmp.getBB(label); + cmp.insert.setNext(trg); + cmp.merge({(size_t)label}); +} + +template <> +void BCCompiler::compile(const RBC& bc, size_t pos) { + push(Nil::instance()); +} + +template <> +void BCCompiler::compile(const RBC& bc, size_t pos) { + insert(new Invisible()); +} + +template <> +void BCCompiler::compile(const RBC& bc, size_t pos) { + insert(new Invisible()); +} + +template <> +void BCCompiler::compile(const RBC& bc, size_t pos) { + insert(new Invisible()); +} + // End instructions translation pir::ClosureVersion* Gnur2Pir::compile(SEXP src, const std::string& name) { @@ -502,25 +589,169 @@ pir::ClosureVersion* Gnur2Pir::compile(SEXP src, const std::string& name) { BCCompiler bccompiler(info); #define SUPPORTED_INSTRUCTIONS(V) \ + V(RBC::GOTO_OP) \ V(RBC::GETVAR_OP) \ V(RBC::RETURN_OP) \ V(RBC::LDCONST_OP) \ - V(RBC::ADD_OP) + V(RBC::ADD_OP) \ + V(RBC::LDNULL_OP) \ + V(RBC::INVISIBLE_OP) \ + V(RBC::INCLNK_OP) \ + V(RBC::DECLNK_OP) \ + V(RBC::BRIFNOT_OP) + +#define UNSUPPORTED_INSTRUCTIONS(V) \ + V(RBC::BCMISMATCH_OP) \ + V(RBC::POP_OP) \ + V(RBC::DUP_OP) \ + V(RBC::PRINTVALUE_OP) \ + V(RBC::STARTLOOPCNTXT_OP) \ + V(RBC::ENDLOOPCNTXT_OP) \ + V(RBC::DOLOOPNEXT_OP) \ + V(RBC::DOLOOPBREAK_OP) \ + V(RBC::STARTFOR_OP) \ + V(RBC::STEPFOR_OP) \ + V(RBC::ENDFOR_OP) \ + V(RBC::SETLOOPVAL_OP) \ + V(RBC::LDTRUE_OP) \ + V(RBC::LDFALSE_OP) \ + V(RBC::DDVAL_OP) \ + V(RBC::SETVAR_OP) \ + V(RBC::GETFUN_OP) \ + V(RBC::GETGLOBFUN_OP) \ + V(RBC::GETSYMFUN_OP) \ + V(RBC::GETBUILTIN_OP) \ + V(RBC::GETINTLBUILTIN_OP) \ + V(RBC::CHECKFUN_OP) \ + V(RBC::MAKEPROM_OP) \ + V(RBC::DOMISSING_OP) \ + V(RBC::SETTAG_OP) \ + V(RBC::DODOTS_OP) \ + V(RBC::PUSHARG_OP) \ + V(RBC::PUSHCONSTARG_OP) \ + V(RBC::PUSHNULLARG_OP) \ + V(RBC::PUSHTRUEARG_OP) \ + V(RBC::PUSHFALSEARG_OP) \ + V(RBC::CALL_OP) \ + V(RBC::CALLBUILTIN_OP) \ + V(RBC::CALLSPECIAL_OP) \ + V(RBC::MAKECLOSURE_OP) \ + V(RBC::UMINUS_OP) \ + V(RBC::UPLUS_OP) \ + V(RBC::SUB_OP) \ + V(RBC::MUL_OP) \ + V(RBC::DIV_OP) \ + V(RBC::EXPT_OP) \ + V(RBC::SQRT_OP) \ + V(RBC::EXP_OP) \ + V(RBC::EQ_OP) \ + V(RBC::NE_OP) \ + V(RBC::LT_OP) \ + V(RBC::LE_OP) \ + V(RBC::GE_OP) \ + V(RBC::GT_OP) \ + V(RBC::AND_OP) \ + V(RBC::OR_OP) \ + V(RBC::NOT_OP) \ + V(RBC::DOTSERR_OP) \ + V(RBC::STARTASSIGN_OP) \ + V(RBC::ENDASSIGN_OP) \ + V(RBC::STARTSUBSET_OP) \ + V(RBC::DFLTSUBSET_OP) \ + V(RBC::STARTSUBASSIGN_OP) \ + V(RBC::DFLTSUBASSIGN_OP) \ + V(RBC::STARTC_OP) \ + V(RBC::DFLTC_OP) \ + V(RBC::STARTSUBSET2_OP) \ + V(RBC::DFLTSUBSET2_OP) \ + V(RBC::STARTSUBASSIGN2_OP) \ + V(RBC::DFLTSUBASSIGN2_OP) \ + V(RBC::DOLLAR_OP) \ + V(RBC::DOLLARGETS_OP) \ + V(RBC::ISNULL_OP) \ + V(RBC::ISLOGICAL_OP) \ + V(RBC::ISINTEGER_OP) \ + V(RBC::ISDOUBLE_OP) \ + V(RBC::ISCOMPLEX_OP) \ + V(RBC::ISCHARACTER_OP) \ + V(RBC::ISSYMBOL_OP) \ + V(RBC::ISOBJECT_OP) \ + V(RBC::ISNUMERIC_OP) \ + V(RBC::VECSUBSET_OP) \ + V(RBC::MATSUBSET_OP) \ + V(RBC::VECSUBASSIGN_OP) \ + V(RBC::MATSUBASSIGN_OP) \ + V(RBC::AND1ST_OP) \ + V(RBC::AND2ND_OP) \ + V(RBC::OR1ST_OP) \ + V(RBC::OR2ND_OP) \ + V(RBC::GETVAR_MISSOK_OP) \ + V(RBC::DDVAL_MISSOK_OP) \ + V(RBC::VISIBLE_OP) \ + V(RBC::SETVAR2_OP) \ + V(RBC::STARTASSIGN2_OP) \ + V(RBC::ENDASSIGN2_OP) \ + V(RBC::SETTER_CALL_OP) \ + V(RBC::GETTER_CALL_OP) \ + V(RBC::SWAP_OP) \ + V(RBC::DUP2ND_OP) \ + V(RBC::SWITCH_OP) \ + V(RBC::RETURNJMP_OP) \ + V(RBC::STARTSUBSET_N_OP) \ + V(RBC::STARTSUBASSIGN_N_OP) \ + V(RBC::VECSUBSET2_OP) \ + V(RBC::MATSUBSET2_OP) \ + V(RBC::VECSUBASSIGN2_OP) \ + V(RBC::MATSUBASSIGN2_OP) \ + V(RBC::STARTSUBSET2_N_OP) \ + V(RBC::STARTSUBASSIGN2_N_OP) \ + V(RBC::SUBSET_N_OP) \ + V(RBC::SUBSET2_N_OP) \ + V(RBC::SUBASSIGN_N_OP) \ + V(RBC::SUBASSIGN2_N_OP) \ + V(RBC::LOG_OP) \ + V(RBC::LOGBASE_OP) \ + V(RBC::MATH1_OP) \ + V(RBC::DOTCALL_OP) \ + V(RBC::COLON_OP) \ + V(RBC::SEQALONG_OP) \ + V(RBC::SEQLEN_OP) \ + V(RBC::BASEGUARD_OP) \ + V(RBC::DECLNK_N_OP) for (auto pc = bytecode.begin(); pc != bytecode.end(); ++pc) { auto bc = *pc; + auto pos = (pc - bytecode.begin()) + 1; + auto hasBB = info.basicBlocks.find(pos); + if (hasBB != info.basicBlocks.end()) { + info.insert.reenterBB(hasBB->second); + auto state = info.bbState.find(hasBB->second); + if (state != info.bbState.end()) + info.stack = state->second; + } + switch (bc.id) { #define V(BC) \ case BC: \ - bccompiler.compile(bc); \ + bccompiler.compile(bc, pos); \ break; SUPPORTED_INSTRUCTIONS(V) #undef V - default: - std::cerr << "Could not compile " << *pc << "\n"; - assert(false); - return nullptr; +#define V(BC) \ + case BC: \ + std::cerr << "Could not compile " << *pc << "\n"; \ + assert(false); \ + return nullptr; + UNSUPPORTED_INSTRUCTIONS(V) +#undef V + } + + auto next = pos + 1 + bc.imm(); + if (bc.falls() && info.mergepoints.count(next) && + !info.stack.isEmpty()) { + info.merge({next}); + info.insert.getCurrentBB()->setNext(info.getBB(next)); } } return v; diff --git a/rir/tests/gnur2pir.R b/rir/tests/gnur2pir.R index ea70d96d8..dbf0c2cdf 100644 --- a/rir/tests/gnur2pir.R +++ b/rir/tests/gnur2pir.R @@ -8,3 +8,7 @@ test <- function(f, arg, expected) { test(function(a) 1,, 1) test(function(a) a, 1, 1) test(function(a) a+1, 1, 2) +test(function(a) if (a) 1, TRUE, 1) +test(function(a) if (a) 1, FALSE, NULL) +test(function(a) 1 + if (a) 1, T, 2) +test(function(a) 2 - 1 + if (a) 1, T, 2)