From 4647df305144a0d559f059f45769536e32f1d8c4 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Mon, 2 May 2022 14:14:32 +0000 Subject: [PATCH 01/12] reduce version cloning --- .vscode/settings.json | 4 +- rir/src/api.cpp | 19 ++++---- rir/src/compiler/compiler.cpp | 13 +++++- .../compiler/native/lower_function_llvm.cpp | 3 ++ rir/src/compiler/opt/eager_calls.cpp | 43 +++++++++++++++++-- rir/src/compiler/opt/match_call_args.cpp | 13 ++++++ rir/src/compiler/pir/closure.cpp | 19 ++++++++ rir/src/compiler/pir/closure.h | 6 +++ rir/src/compiler/pir/closure_version.h | 6 ++- rir/src/compiler/pir/instruction.cpp | 13 ++++++ rir/src/compiler/pir/instruction.h | 4 ++ rir/src/compiler/rir2pir/rir2pir.cpp | 5 +++ 12 files changed, 133 insertions(+), 15 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 3c989f58d..7a261015a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -112,6 +112,8 @@ "cfenv": "cpp", "csignal": "cpp", "__functional_base_03": "cpp", - "__memory": "cpp" + "__memory": "cpp", + "__bits": "cpp", + "__availability": "cpp" } } diff --git a/rir/src/api.cpp b/rir/src/api.cpp index 132792164..7eb0d89ba 100644 --- a/rir/src/api.cpp +++ b/rir/src/api.cpp @@ -313,22 +313,23 @@ SEXP pirCompile(SEXP what, const Context& assumptions, const std::string& name, return; rir::Function* done = nullptr; - auto apply = [&](SEXP body, pir::ClosureVersion* c) { - auto fun = backend.getOrCompile(c); + auto apply = [&](SEXP body, pir::ClosureVersion* cv) { + auto fun = backend.getOrCompile(cv); Protect p(fun->container()); DispatchTable::unpack(body)->insert(fun); if (body == BODY(what)) done = fun; }; - m->eachPirClosureVersion([&](pir::ClosureVersion* c) { - if (c->owner()->hasOriginClosure()) { - auto cls = c->owner()->rirClosure(); + + m->eachPirClosureVersion([&](pir::ClosureVersion* eachVersion) { + if (eachVersion->owner()->hasOriginClosure()) { + auto cls = eachVersion->owner()->rirClosure(); auto body = BODY(cls); auto dt = DispatchTable::unpack(body); - if (dt->contains(c->context())) { - auto other = dt->dispatch(c->context()); + if (dt->contains(eachVersion->context())) { + auto other = dt->dispatch(eachVersion->context()); assert(other != dt->baseline()); - assert(other->context() == c->context()); + assert(other->context() == eachVersion->context()); if (other->body()->isCompiled()) return; } @@ -336,7 +337,7 @@ SEXP pirCompile(SEXP what, const Context& assumptions, const std::string& name, // they have incomplete type-feedback. if (dt->size() == 1 && dt->baseline()->invocationCount() < 2) return; - apply(body, c); + apply(body, eachVersion); } }); if (!done) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 8d5bd04f1..25977a6e6 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -243,6 +243,11 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { std::unordered_map> reachable; bool changed = true; + m->eachPirClosure([&](Closure* c) { + c->hasBeenCloned = false; + c->eachVersion([&](ClosureVersion* v) { v->staticCallRefCount = 0; }); + }); + auto found = [&](ClosureVersion* v) { if (!v) return; @@ -283,7 +288,12 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { i->printRecursive(msg, 2); log.warn(msg.str()); } - found(call->tryDispatch()); + + auto dispatched = call->tryDispatch(); + found(dispatched); + call->lastSeen = dispatched; + dispatched->staticCallRefCount++; + found(call->tryOptimisticDispatch()); found(call->hint); } else if (auto call = CallInstruction::CastCall(i)) { @@ -310,6 +320,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { m->eachPirClosure([&](Closure* c) { const auto& reachableVersions = reachable[c]; c->eachVersion([&](ClosureVersion* v) { + // assert(c->getVersion(v->context()) == v); if (!reachableVersions.count(v->context())) { toErase.push_back({v->owner(), v->context()}); log.close(v); diff --git a/rir/src/compiler/native/lower_function_llvm.cpp b/rir/src/compiler/native/lower_function_llvm.cpp index 03038c143..1fbc58994 100644 --- a/rir/src/compiler/native/lower_function_llvm.cpp +++ b/rir/src/compiler/native/lower_function_llvm.cpp @@ -3388,6 +3388,9 @@ void LowerFunctionLLVM::compile() { if (calli->isReordered()) callId = pushArgReordering(calli->getArgOrderOrig()); + if (!target) + assert(target && "target is null!"); + if (!target->owner()->hasOriginClosure()) { setVal( i, withCallFrame(args, [&]() -> llvm::Value* { diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index 63e2d8051..46bceac39 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -287,8 +287,10 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, // version. Maybe we should limit this at some point, to avoid // version explosion. if (availableAssumptions.isImproving(version)) { - auto newVersion = target->cloneWithAssumptions( - version, availableAssumptions, + + ClosureVersion* newVersion; + + auto updateVersionWithNewAssumptions = [&](ClosureVersion* newCls) { Visitor::run(newCls->entry, [&](Instruction* i) { if (auto f = Force::Cast(i)) { @@ -308,7 +310,42 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, } } }); - }); + }; + + if (version != call->lastSeen) { + version->staticCallRefCount++; + } + + call->lastSeen = nullptr; + if (version->owner()->hasBeenCloned || + version->staticCallRefCount > 1) { + + newVersion = target->cloneWithAssumptions( + version, availableAssumptions, + [&](ClosureVersion* newCls) { + updateVersionWithNewAssumptions(newCls); + }); + + if (newVersion != version) { + version->owner()->hasBeenCloned = true; + + // newVersion->staticCallRefCount++; + // call->lastSeen = newVersion; + // version->staticCallRefCount--; + } + + } else { + + newVersion = target->replaceWithAssumptions( + version, availableAssumptions, + updateVersionWithNewAssumptions); + + call->lastSeen = newVersion; + if (newVersion->staticCallRefCount == 0) { + newVersion->staticCallRefCount = 1; + } + } + call->hint = newVersion; assert(call->tryDispatch() == newVersion); ip = next; diff --git a/rir/src/compiler/opt/match_call_args.cpp b/rir/src/compiler/opt/match_call_args.cpp index 3c9057b03..c9374761e 100644 --- a/rir/src/compiler/opt/match_call_args.cpp +++ b/rir/src/compiler/opt/match_call_args.cpp @@ -242,23 +242,36 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (auto c = call) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); + target->staticCallRefCount++; auto nc = new StaticCall( c->env(), target->owner(), asmpt, matchedArgs, argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); + + nc->lastSeen = target; + (*ip)->replaceUsesAndSwapWith(nc, ip); } else if (auto c = namedCall) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); + target->staticCallRefCount++; auto nc = new StaticCall( c->env(), target->owner(), asmpt, matchedArgs, argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); + + nc->lastSeen = target; + (*ip)->replaceUsesAndSwapWith(nc, ip); } else if (auto c = staticCall) { assert(usemethodTarget); auto cls = cmp.module->c(usemethodTarget); + target->staticCallRefCount++; + auto nc = new StaticCall( c->env(), target->owner(), asmpt, matchedArgs, argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); + + nc->lastSeen = target; + (*ip)->replaceUsesAndSwapWith(nc, ip); } else { assert(false); diff --git a/rir/src/compiler/pir/closure.cpp b/rir/src/compiler/pir/closure.cpp index cb9e3df6f..988531403 100644 --- a/rir/src/compiler/pir/closure.cpp +++ b/rir/src/compiler/pir/closure.cpp @@ -54,6 +54,25 @@ ClosureVersion* Closure::cloneWithAssumptions(ClosureVersion* version, return copy; } +ClosureVersion* Closure::replaceWithAssumptions(ClosureVersion* version, + const Context& asmpt, + const MaybeClsVersion& change) { + assert(versions.count(version->context()) > 0); + + auto newCtx = version->context() | asmpt; + if (versions.count(newCtx)) { + return versions.at(newCtx); + } + + versions.erase(version->context()); + + version->setContext(newCtx); + + versions[newCtx] = version; + change(version); + return version; +} + ClosureVersion* Closure::findCompatibleVersion(const Context& ctx) const { // ordered by number of assumptions for (auto& candidate : versions) { diff --git a/rir/src/compiler/pir/closure.h b/rir/src/compiler/pir/closure.h index e07df5061..5ea0bdea7 100644 --- a/rir/src/compiler/pir/closure.h +++ b/rir/src/compiler/pir/closure.h @@ -51,6 +51,8 @@ class Closure { Context userContext_; public: + bool hasBeenCloned = false; + bool matchesUserContext(Context c) const { return c.smaller(this->userContext_); } @@ -90,6 +92,10 @@ class Closure { const Context& asmpt, const MaybeClsVersion& change); + ClosureVersion* replaceWithAssumptions(ClosureVersion* cls, + const Context& asmpt, + const MaybeClsVersion& change); + typedef std::function ClosureVersionIterator; void eachVersion(ClosureVersionIterator it) const; diff --git a/rir/src/compiler/pir/closure_version.h b/rir/src/compiler/pir/closure_version.h index 7421dc226..72efd85c9 100644 --- a/rir/src/compiler/pir/closure_version.h +++ b/rir/src/compiler/pir/closure_version.h @@ -42,11 +42,12 @@ class ClosureVersion : public Code { const bool root; rir::Function* optFunction; + size_t staticCallRefCount = 0; private: Closure* owner_; std::vector promises_; - const Context optimizationContext_; + Context optimizationContext_; std::string name_; std::string nameSuffix_; @@ -62,6 +63,9 @@ class ClosureVersion : public Code { ClosureVersion* clone(const Context& newContext); const Context& context() const { return optimizationContext_; } + void setContext(const Context& newContext) { + optimizationContext_ = newContext; + } Properties properties; diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index ab2235c17..7fa884d12 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -1275,6 +1275,19 @@ StaticCall::StaticCall(Value* callerEnv, Closure* cls, Context givenContext, assert(tryDispatch()); } +Instruction* StaticCall::clone() const { + auto r = InstructionImplementation::clone(); + + auto sc = StaticCall::Cast(r); + sc->lastSeen = nullptr; + auto target = sc->tryDispatch(); + if (target) { + sc->lastSeen = target; + target->staticCallRefCount++; + } + return sc; +} + PirType StaticCall::inferType(const GetType& getType) const { auto t = PirType::bottom(); if (auto v = tryDispatch()) { diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index c321b96e9..29bee9368 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -62,6 +62,7 @@ namespace pir { class BB; class Closure; + class Phi; struct InstrArg { @@ -2268,6 +2269,7 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs, unsigned srcIdx, Value* runtimeClosure = Tombstone::closure()); + ClosureVersion* lastSeen = nullptr; Context givenContext; ClosureVersion* hint = nullptr; @@ -2279,6 +2281,8 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { size_t nCallArgs() const override { return nargs() - 3; } + Instruction* clone() const override; + void eachNamedCallArg(const NamedArgumentValueIterator& it) const override { for (size_t i = 0; i < nCallArgs(); ++i) it(R_NilValue, arg(i + 2).val()); diff --git a/rir/src/compiler/rir2pir/rir2pir.cpp b/rir/src/compiler/rir2pir/rir2pir.cpp index 508213071..1a930ad6f 100644 --- a/rir/src/compiler/rir2pir/rir2pir.cpp +++ b/rir/src/compiler/rir2pir/rir2pir.cpp @@ -931,13 +931,18 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, assert(!inlining()); auto fs = insert.registerFrameState(srcCode, nextPos, stack, inPromise()); + f->staticCallRefCount++; auto cl = insert( new StaticCall(insert.env, f->owner(), given, matchedArgs, std::move(argOrderOrig), fs, ast, f->owner()->closureEnv() == Env::notClosed() ? guardedCallee : Tombstone::closure())); + + cl->lastSeen = f; + cl->effects.set(Effect::DependsOnAssume); + push(cl); auto innerc = MkCls::Cast(guardedCallee->followCastsAndForce()); From 4d8eaf7c9d4f3e1d38d1f2edc161e4758ac6bb42 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Tue, 3 May 2022 08:09:51 +0000 Subject: [PATCH 02/12] try fix for feature1 --- rir/src/compiler/opt/hoist_instruction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rir/src/compiler/opt/hoist_instruction.cpp b/rir/src/compiler/opt/hoist_instruction.cpp index d99666679..80dda7ac2 100644 --- a/rir/src/compiler/opt/hoist_instruction.cpp +++ b/rir/src/compiler/opt/hoist_instruction.cpp @@ -284,7 +284,7 @@ bool HoistInstruction::apply(Compiler& cmp, ClosureVersion* cls, Code* code, } if ((*it1)->hasObservableEffects()) break; - if (it1 == bb2->end()) + if (it1 == bb1->end()) break; it1++; } From b93f2b2e411e719ef42722b1e82ef2994552e7a3 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Tue, 3 May 2022 09:20:21 +0000 Subject: [PATCH 03/12] added isClone flag in closure version --- rir/src/compiler/compiler.cpp | 54 ++++++++++++++++++++++---- rir/src/compiler/opt/eager_calls.cpp | 11 +++--- rir/src/compiler/pir/closure.h | 1 - rir/src/compiler/pir/closure_version.h | 3 ++ 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 25977a6e6..46785e613 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -243,10 +243,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { std::unordered_map> reachable; bool changed = true; - m->eachPirClosure([&](Closure* c) { - c->hasBeenCloned = false; - c->eachVersion([&](ClosureVersion* v) { v->staticCallRefCount = 0; }); - }); + auto found = [&](ClosureVersion* v) { if (!v) @@ -289,11 +286,8 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { log.warn(msg.str()); } - auto dispatched = call->tryDispatch(); - found(dispatched); - call->lastSeen = dispatched; - dispatched->staticCallRefCount++; + found(call->tryDispatch()); found(call->tryOptimisticDispatch()); found(call->hint); } else if (auto call = CallInstruction::CastCall(i)) { @@ -331,6 +325,50 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { for (auto e : toErase) e.first->erase(e.second); + + + + // reset refCount state + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + v->staticCallRefCount = 0; + v->isClone = false; + }); + }); + + m->eachPirClosure([&](Closure* c) { + + c->eachVersion([&](ClosureVersion* v) { + + auto check = [&](Instruction* i) { + if (auto call = StaticCall::Cast(i)) { + + auto dispatched = call->tryDispatch(); + call->lastSeen = dispatched; + dispatched->staticCallRefCount++; + } + // else if (auto call = CallInstruction::CastCall(i)) { + // if (auto cls = call->tryGetCls()) + // found(call->tryDispatch(cls)); + // } else { + // i->eachArg([&](Value* v) { + // if (auto mk = MkCls::Cast(i)) { + // if (mk->tryGetCls()) + // mk->tryGetCls()->eachVersion(found); + // } + // }); + // } + }; + + Visitor::run(v->entry, check); + v->eachPromise( + [&](Promise* p) { Visitor::run(p->entry, check); }); + + }); + }); + + + }; void Compiler::optimizeClosureVersion(ClosureVersion* v) { diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index 46bceac39..fbe3f411d 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -317,7 +317,7 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, } call->lastSeen = nullptr; - if (version->owner()->hasBeenCloned || + if (version->isClone || version->staticCallRefCount > 1) { newVersion = target->cloneWithAssumptions( @@ -327,15 +327,16 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, }); if (newVersion != version) { - version->owner()->hasBeenCloned = true; + newVersion->isClone = true; - // newVersion->staticCallRefCount++; - // call->lastSeen = newVersion; - // version->staticCallRefCount--; + + call->lastSeen = newVersion; + version->staticCallRefCount--; } } else { + newVersion = target->replaceWithAssumptions( version, availableAssumptions, updateVersionWithNewAssumptions); diff --git a/rir/src/compiler/pir/closure.h b/rir/src/compiler/pir/closure.h index 5ea0bdea7..52cc56a40 100644 --- a/rir/src/compiler/pir/closure.h +++ b/rir/src/compiler/pir/closure.h @@ -51,7 +51,6 @@ class Closure { Context userContext_; public: - bool hasBeenCloned = false; bool matchesUserContext(Context c) const { return c.smaller(this->userContext_); diff --git a/rir/src/compiler/pir/closure_version.h b/rir/src/compiler/pir/closure_version.h index 72efd85c9..d630bc23e 100644 --- a/rir/src/compiler/pir/closure_version.h +++ b/rir/src/compiler/pir/closure_version.h @@ -60,6 +60,9 @@ class ClosureVersion : public Code { friend class Closure; public: + + bool isClone = false; + ClosureVersion* clone(const Context& newContext); const Context& context() const { return optimizationContext_; } From 319be18870502b778aa57ed3297623cfe82c913b Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Tue, 3 May 2022 09:24:39 +0000 Subject: [PATCH 04/12] added isClone flag at closureVersion level --- rir/src/compiler/compiler.cpp | 14 +------------- rir/src/compiler/opt/eager_calls.cpp | 5 +---- rir/src/compiler/pir/closure_version.h | 1 - 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 46785e613..79e167134 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -243,8 +243,6 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { std::unordered_map> reachable; bool changed = true; - - auto found = [&](ClosureVersion* v) { if (!v) return; @@ -286,7 +284,6 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { log.warn(msg.str()); } - found(call->tryDispatch()); found(call->tryOptimisticDispatch()); found(call->hint); @@ -326,8 +323,6 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { for (auto e : toErase) e.first->erase(e.second); - - // reset refCount state m->eachPirClosure([&](Closure* c) { c->eachVersion([&](ClosureVersion* v) { @@ -337,9 +332,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { }); m->eachPirClosure([&](Closure* c) { - c->eachVersion([&](ClosureVersion* v) { - auto check = [&](Instruction* i) { if (auto call = StaticCall::Cast(i)) { @@ -361,14 +354,9 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { }; Visitor::run(v->entry, check); - v->eachPromise( - [&](Promise* p) { Visitor::run(p->entry, check); }); - + v->eachPromise([&](Promise* p) { Visitor::run(p->entry, check); }); }); }); - - - }; void Compiler::optimizeClosureVersion(ClosureVersion* v) { diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index fbe3f411d..2854de62d 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -317,8 +317,7 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, } call->lastSeen = nullptr; - if (version->isClone || - version->staticCallRefCount > 1) { + if (version->isClone || version->staticCallRefCount > 1) { newVersion = target->cloneWithAssumptions( version, availableAssumptions, @@ -329,14 +328,12 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (newVersion != version) { newVersion->isClone = true; - call->lastSeen = newVersion; version->staticCallRefCount--; } } else { - newVersion = target->replaceWithAssumptions( version, availableAssumptions, updateVersionWithNewAssumptions); diff --git a/rir/src/compiler/pir/closure_version.h b/rir/src/compiler/pir/closure_version.h index d630bc23e..5be9049b9 100644 --- a/rir/src/compiler/pir/closure_version.h +++ b/rir/src/compiler/pir/closure_version.h @@ -60,7 +60,6 @@ class ClosureVersion : public Code { friend class Closure; public: - bool isClone = false; ClosureVersion* clone(const Context& newContext); From f087559473ce89252a7a4fa2545c65c65910fa9a Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Tue, 3 May 2022 12:33:54 +0000 Subject: [PATCH 05/12] minor fix --- rir/src/compiler/compiler.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 79e167134..d819b2e9f 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -337,8 +337,10 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { if (auto call = StaticCall::Cast(i)) { auto dispatched = call->tryDispatch(); - call->lastSeen = dispatched; - dispatched->staticCallRefCount++; + if (dispatched) { + call->lastSeen = dispatched; + dispatched->staticCallRefCount++; + } } // else if (auto call = CallInstruction::CastCall(i)) { // if (auto cls = call->tryGetCls()) From 10233b51c95de44b69b1f78e91c7069e2e49ef37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Krynski?= <49732803+skrynski@users.noreply.github.com> Date: Tue, 3 May 2022 17:03:00 +0200 Subject: [PATCH 06/12] minor fix 2 - lastSeen=nullptr; --- rir/src/compiler/compiler.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index d819b2e9f..fcdd44cd5 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -336,7 +336,8 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { auto check = [&](Instruction* i) { if (auto call = StaticCall::Cast(i)) { - auto dispatched = call->tryDispatch(); + call->lastSeen = nullptr; + auto dispatched = call->tryDispatch(); if (dispatched) { call->lastSeen = dispatched; dispatched->staticCallRefCount++; From 2b06a5015314d0801998a2f8c7e3139a072f25d5 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Tue, 3 May 2022 16:24:54 +0000 Subject: [PATCH 07/12] run again --- rir/src/compiler/compiler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index fcdd44cd5..3761863d2 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -336,8 +336,8 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { auto check = [&](Instruction* i) { if (auto call = StaticCall::Cast(i)) { - call->lastSeen = nullptr; - auto dispatched = call->tryDispatch(); + call->lastSeen = nullptr; + auto dispatched = call->tryDispatch(); if (dispatched) { call->lastSeen = dispatched; dispatched->staticCallRefCount++; From 8c27e44bd4893f00788e413340d969ced5b62638 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Wed, 4 May 2022 06:35:18 +0000 Subject: [PATCH 08/12] improve refCount reset after versions' GC --- rir/src/compiler/compiler.cpp | 81 +++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 3761863d2..bd10f9ce5 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -239,6 +239,49 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction, bool MEASURE_COMPILER_PERF = getenv("PIR_MEASURE_COMPILER") ? true : false; +static void resetVersionsRefCountState(Module* m) { + + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + v->staticCallRefCount = 0; + v->isClone = false; + }); + }); + + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + auto check = [&](Instruction* i) { + if (auto call = StaticCall::Cast(i)) { + + call->lastSeen = nullptr; + + if (auto dispatchedVersion = call->tryDispatch()) { + call->lastSeen = dispatchedVersion; + dispatchedVersion->staticCallRefCount++; + } + } else if (auto call = CallInstruction::CastCall(i)) { + if (auto cls = call->tryGetCls()) { + + if (auto dispatchedVersion = call->tryDispatch(cls)) { + dispatchedVersion->staticCallRefCount++; + } + } + } + }; + + Visitor::run(v->entry, check); + v->eachPromise([&](Promise* p) { Visitor::run(p->entry, check); }); + }); + }); + + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + if (v->staticCallRefCount == 0) + v->staticCallRefCount = 1; + }); + }); +} + static void findUnreachable(Module* m, Log& log, const std::string& where) { std::unordered_map> reachable; bool changed = true; @@ -323,43 +366,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { for (auto e : toErase) e.first->erase(e.second); - // reset refCount state - m->eachPirClosure([&](Closure* c) { - c->eachVersion([&](ClosureVersion* v) { - v->staticCallRefCount = 0; - v->isClone = false; - }); - }); - - m->eachPirClosure([&](Closure* c) { - c->eachVersion([&](ClosureVersion* v) { - auto check = [&](Instruction* i) { - if (auto call = StaticCall::Cast(i)) { - - call->lastSeen = nullptr; - auto dispatched = call->tryDispatch(); - if (dispatched) { - call->lastSeen = dispatched; - dispatched->staticCallRefCount++; - } - } - // else if (auto call = CallInstruction::CastCall(i)) { - // if (auto cls = call->tryGetCls()) - // found(call->tryDispatch(cls)); - // } else { - // i->eachArg([&](Value* v) { - // if (auto mk = MkCls::Cast(i)) { - // if (mk->tryGetCls()) - // mk->tryGetCls()->eachVersion(found); - // } - // }); - // } - }; - - Visitor::run(v->entry, check); - v->eachPromise([&](Promise* p) { Visitor::run(p->entry, check); }); - }); - }); + resetVersionsRefCountState(m); }; void Compiler::optimizeClosureVersion(ClosureVersion* v) { From 33df2e756fed7056f664f539a876de7194369c92 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Fri, 6 May 2022 09:37:39 +0000 Subject: [PATCH 09/12] sanity check --- rir/src/compiler/opt/eager_calls.cpp | 2 +- rir/src/compiler/pir/instruction.cpp | 24 ++++++++++++------------ rir/src/compiler/pir/instruction.h | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index 2854de62d..1c5ba6d3b 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -315,7 +315,7 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (version != call->lastSeen) { version->staticCallRefCount++; } - + version->isClone = true; call->lastSeen = nullptr; if (version->isClone || version->staticCallRefCount > 1) { diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index 7fa884d12..7fb3106c3 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -1275,18 +1275,18 @@ StaticCall::StaticCall(Value* callerEnv, Closure* cls, Context givenContext, assert(tryDispatch()); } -Instruction* StaticCall::clone() const { - auto r = InstructionImplementation::clone(); - - auto sc = StaticCall::Cast(r); - sc->lastSeen = nullptr; - auto target = sc->tryDispatch(); - if (target) { - sc->lastSeen = target; - target->staticCallRefCount++; - } - return sc; -} +// Instruction* StaticCall::clone() const { +// auto r = InstructionImplementation::clone(); + +// auto sc = StaticCall::Cast(r); +// sc->lastSeen = nullptr; +// auto target = sc->tryDispatch(); +// if (target) { +// sc->lastSeen = target; +// target->staticCallRefCount++; +// } +// return sc; +// } PirType StaticCall::inferType(const GetType& getType) const { auto t = PirType::bottom(); diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index 29bee9368..f349760f6 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -2281,7 +2281,7 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { size_t nCallArgs() const override { return nargs() - 3; } - Instruction* clone() const override; + // Instruction* clone() const override; void eachNamedCallArg(const NamedArgumentValueIterator& it) const override { for (size_t i = 0; i < nCallArgs(); ++i) From 73f6785f3d00341d08d687807618c39cd0e91aca Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Tue, 10 May 2022 15:38:43 +0000 Subject: [PATCH 10/12] update version refCount in StaticCall constructor --- rir/src/compiler/opt/eager_calls.cpp | 2 +- rir/src/compiler/opt/match_call_args.cpp | 24 +++++--------- rir/src/compiler/pir/instruction.cpp | 41 +++++++++++++++--------- rir/src/compiler/pir/instruction.h | 9 ++++-- rir/src/compiler/rir2pir/rir2pir.cpp | 7 ++-- 5 files changed, 43 insertions(+), 40 deletions(-) diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index 1c5ba6d3b..2854de62d 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -315,7 +315,7 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (version != call->lastSeen) { version->staticCallRefCount++; } - version->isClone = true; + call->lastSeen = nullptr; if (version->isClone || version->staticCallRefCount > 1) { diff --git a/rir/src/compiler/opt/match_call_args.cpp b/rir/src/compiler/opt/match_call_args.cpp index c9374761e..b0c190cf0 100644 --- a/rir/src/compiler/opt/match_call_args.cpp +++ b/rir/src/compiler/opt/match_call_args.cpp @@ -242,35 +242,27 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (auto c = call) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); - target->staticCallRefCount++; - auto nc = new StaticCall( - c->env(), target->owner(), asmpt, matchedArgs, - argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); - nc->lastSeen = target; + auto nc = new StaticCall( + c->env(), target, asmpt, matchedArgs, argOrderOrig, + c->frameStateOrTs(), c->srcIdx, cls); (*ip)->replaceUsesAndSwapWith(nc, ip); } else if (auto c = namedCall) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); - target->staticCallRefCount++; - auto nc = new StaticCall( - c->env(), target->owner(), asmpt, matchedArgs, - argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); - nc->lastSeen = target; + auto nc = new StaticCall( + c->env(), target, asmpt, matchedArgs, argOrderOrig, + c->frameStateOrTs(), c->srcIdx, cls); (*ip)->replaceUsesAndSwapWith(nc, ip); } else if (auto c = staticCall) { assert(usemethodTarget); auto cls = cmp.module->c(usemethodTarget); - target->staticCallRefCount++; - auto nc = new StaticCall( - c->env(), target->owner(), asmpt, matchedArgs, - argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); - - nc->lastSeen = target; + c->env(), target, asmpt, matchedArgs, argOrderOrig, + c->frameStateOrTs(), c->srcIdx, cls); (*ip)->replaceUsesAndSwapWith(nc, ip); } else { diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index 7fb3106c3..c7ef2e350 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -1251,12 +1251,16 @@ NamedCall::NamedCall(Value* callerEnv, Value* fun, } } -StaticCall::StaticCall(Value* callerEnv, Closure* cls, Context givenContext, - const std::vector& args, +StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion, + Context givenContext, const std::vector& args, const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs, unsigned srcIdx, Value* runtimeClosure) : VarLenInstructionWithEnvSlot(PirType::val(), callerEnv, srcIdx), - cls_(cls), argOrderOrig(argOrderOrig), givenContext(givenContext) { + cls_(clsVersion->owner()), argOrderOrig(argOrderOrig), + givenContext(givenContext) { + + auto cls = clsVersion->owner(); + assert(cls->nargs() >= args.size()); assert(fs); pushArg(fs, NativeType::frameState); @@ -1272,21 +1276,28 @@ StaticCall::StaticCall(Value* callerEnv, Closure* cls, Context givenContext, PirType() | RType::prom | RType::missing | PirType::val()); } } - assert(tryDispatch()); + + assert(tryDispatch() == clsVersion); + + // Update version-refCount fields + clsVersion->staticCallRefCount++; + lastSeen = clsVersion; } -// Instruction* StaticCall::clone() const { -// auto r = InstructionImplementation::clone(); +void StaticCall::updateVersionRefCount() { + lastSeen = nullptr; + auto target = tryDispatch(); + if (target) { + lastSeen = target; + target->staticCallRefCount++; + } +} -// auto sc = StaticCall::Cast(r); -// sc->lastSeen = nullptr; -// auto target = sc->tryDispatch(); -// if (target) { -// sc->lastSeen = target; -// target->staticCallRefCount++; -// } -// return sc; -// } +Instruction* StaticCall::clone() const { + auto sc = StaticCall::Cast(InstructionImplementation::clone()); + sc->updateVersionRefCount(); + return sc; +} PirType StaticCall::inferType(const GetType& getType) const { auto t = PirType::bottom(); diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index f349760f6..da13bdc1a 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -2264,8 +2264,8 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { ArglistOrder::CallArglistOrder argOrderOrig; public: - StaticCall(Value * callerEnv, Closure * cls, Context givenContext, - const std::vector& args, + StaticCall(Value * callerEnv, ClosureVersion * clsVersion, + Context givenContext, const std::vector& args, const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs, unsigned srcIdx, Value* runtimeClosure = Tombstone::closure()); @@ -2281,7 +2281,7 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { size_t nCallArgs() const override { return nargs() - 3; } - // Instruction* clone() const override; + Instruction* clone() const override; void eachNamedCallArg(const NamedArgumentValueIterator& it) const override { for (size_t i = 0; i < nCallArgs(); ++i) @@ -2334,6 +2334,9 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { assert((res & minimal) == minimal); return res; } + + private: + void updateVersionRefCount(); }; typedef SEXP (*CCODE)(SEXP, SEXP, SEXP, SEXP); diff --git a/rir/src/compiler/rir2pir/rir2pir.cpp b/rir/src/compiler/rir2pir/rir2pir.cpp index 1a930ad6f..ec5cdfebc 100644 --- a/rir/src/compiler/rir2pir/rir2pir.cpp +++ b/rir/src/compiler/rir2pir/rir2pir.cpp @@ -931,18 +931,15 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, assert(!inlining()); auto fs = insert.registerFrameState(srcCode, nextPos, stack, inPromise()); - f->staticCallRefCount++; + auto cl = insert( - new StaticCall(insert.env, f->owner(), given, matchedArgs, + new StaticCall(insert.env, f, given, matchedArgs, std::move(argOrderOrig), fs, ast, f->owner()->closureEnv() == Env::notClosed() ? guardedCallee : Tombstone::closure())); - cl->lastSeen = f; - cl->effects.set(Effect::DependsOnAssume); - push(cl); auto innerc = MkCls::Cast(guardedCallee->followCastsAndForce()); From abc4072a8dff327941383c15632696aba074f28b Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Wed, 11 May 2022 10:19:14 +0000 Subject: [PATCH 11/12] refactoring StaticCall constructor --- rir/src/compiler/opt/match_call_args.cpp | 16 ++++++++++------ rir/src/compiler/pir/instruction.cpp | 22 ++++++++++++---------- rir/src/compiler/pir/instruction.h | 7 ++++--- rir/src/compiler/rir2pir/rir2pir.cpp | 9 ++++++--- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/rir/src/compiler/opt/match_call_args.cpp b/rir/src/compiler/opt/match_call_args.cpp index b0c190cf0..e1e854e9a 100644 --- a/rir/src/compiler/opt/match_call_args.cpp +++ b/rir/src/compiler/opt/match_call_args.cpp @@ -239,13 +239,15 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (staticallyArgmatched && target) { anyChange = true; + auto emptyAfterBlock = [](StaticCall* call) {}; if (auto c = call) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); auto nc = new StaticCall( - c->env(), target, asmpt, matchedArgs, argOrderOrig, - c->frameStateOrTs(), c->srcIdx, cls); + c->env(), target->owner(), asmpt, matchedArgs, + argOrderOrig, c->frameStateOrTs(), c->srcIdx, + emptyAfterBlock, cls); (*ip)->replaceUsesAndSwapWith(nc, ip); } else if (auto c = namedCall) { @@ -253,16 +255,18 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code, auto cls = c->cls()->followCastsAndForce(); auto nc = new StaticCall( - c->env(), target, asmpt, matchedArgs, argOrderOrig, - c->frameStateOrTs(), c->srcIdx, cls); + c->env(), target->owner(), asmpt, matchedArgs, + argOrderOrig, c->frameStateOrTs(), c->srcIdx, + emptyAfterBlock, cls); (*ip)->replaceUsesAndSwapWith(nc, ip); } else if (auto c = staticCall) { assert(usemethodTarget); auto cls = cmp.module->c(usemethodTarget); auto nc = new StaticCall( - c->env(), target, asmpt, matchedArgs, argOrderOrig, - c->frameStateOrTs(), c->srcIdx, cls); + c->env(), target->owner(), asmpt, matchedArgs, + argOrderOrig, c->frameStateOrTs(), c->srcIdx, + emptyAfterBlock, cls); (*ip)->replaceUsesAndSwapWith(nc, ip); } else { diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index c7ef2e350..a6e50bcc3 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -1251,15 +1251,15 @@ NamedCall::NamedCall(Value* callerEnv, Value* fun, } } -StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion, - Context givenContext, const std::vector& args, +StaticCall::StaticCall(Value* callerEnv, Closure* cls, Context givenContext, + const std::vector& args, const ArglistOrder::CallArglistOrder& argOrderOrig, - Value* fs, unsigned srcIdx, Value* runtimeClosure) - : VarLenInstructionWithEnvSlot(PirType::val(), callerEnv, srcIdx), - cls_(clsVersion->owner()), argOrderOrig(argOrderOrig), - givenContext(givenContext) { + Value* fs, unsigned srcIdx, + const std::function& after, + Value* runtimeClosure) - auto cls = clsVersion->owner(); + : VarLenInstructionWithEnvSlot(PirType::val(), callerEnv, srcIdx), + cls_(cls), argOrderOrig(argOrderOrig), givenContext(givenContext) { assert(cls->nargs() >= args.size()); assert(fs); @@ -1277,11 +1277,13 @@ StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion, } } - assert(tryDispatch() == clsVersion); + after(this); + auto dispatched = tryDispatch(); + assert(dispatched); // Update version-refCount fields - clsVersion->staticCallRefCount++; - lastSeen = clsVersion; + dispatched->staticCallRefCount++; + lastSeen = dispatched; } void StaticCall::updateVersionRefCount() { diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index da13bdc1a..d24511bff 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -2264,10 +2264,11 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { ArglistOrder::CallArglistOrder argOrderOrig; public: - StaticCall(Value * callerEnv, ClosureVersion * clsVersion, - Context givenContext, const std::vector& args, + StaticCall(Value * callerEnv, Closure * cls, Context givenContext, + const std::vector& args, const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs, - unsigned srcIdx, Value* runtimeClosure = Tombstone::closure()); + unsigned srcIdx, const std::function& after, + Value* runtimeClosure = Tombstone::closure()); ClosureVersion* lastSeen = nullptr; Context givenContext; diff --git a/rir/src/compiler/rir2pir/rir2pir.cpp b/rir/src/compiler/rir2pir/rir2pir.cpp index ec5cdfebc..253eee83d 100644 --- a/rir/src/compiler/rir2pir/rir2pir.cpp +++ b/rir/src/compiler/rir2pir/rir2pir.cpp @@ -932,14 +932,17 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, auto fs = insert.registerFrameState(srcCode, nextPos, stack, inPromise()); + auto after = [](StaticCall* call) { + call->effects.set(Effect::DependsOnAssume); + }; + auto cl = insert( - new StaticCall(insert.env, f, given, matchedArgs, - std::move(argOrderOrig), fs, ast, + new StaticCall(insert.env, f->owner(), given, matchedArgs, + std::move(argOrderOrig), fs, ast, after, f->owner()->closureEnv() == Env::notClosed() ? guardedCallee : Tombstone::closure())); - cl->effects.set(Effect::DependsOnAssume); push(cl); auto innerc = MkCls::Cast(guardedCallee->followCastsAndForce()); From a185ad4ae8db2a0543a12164674bb19c06aa5543 Mon Sep 17 00:00:00 2001 From: Sebastian Krynski Date: Wed, 11 May 2022 16:51:52 +0000 Subject: [PATCH 12/12] more refactorings --- rir/src/compiler/compiler.cpp | 8 +-- rir/src/compiler/opt/eager_calls.cpp | 75 ++++++++++++++++------------ rir/src/compiler/pir/instruction.h | 4 +- 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index bd10f9ce5..c4fee994b 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -252,13 +252,7 @@ static void resetVersionsRefCountState(Module* m) { c->eachVersion([&](ClosureVersion* v) { auto check = [&](Instruction* i) { if (auto call = StaticCall::Cast(i)) { - - call->lastSeen = nullptr; - - if (auto dispatchedVersion = call->tryDispatch()) { - call->lastSeen = dispatchedVersion; - dispatchedVersion->staticCallRefCount++; - } + call->updateVersionRefCount(); } else if (auto call = CallInstruction::CastCall(i)) { if (auto cls = call->tryGetCls()) { diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index 2854de62d..bee4727ef 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -15,6 +15,46 @@ namespace rir { namespace pir { +static ClosureVersion* +cloneOrReplaceVersion(Closure* target, ClosureVersion* version, + StaticCall* call, Context assumptions, + const std::function& + updateVersionWithNewAssumptions) { + if (version != call->lastSeen) { + version->staticCallRefCount++; + } + + ClosureVersion* newVersion; + + call->lastSeen = nullptr; + if (version->isClone || version->staticCallRefCount > 1) { + + newVersion = target->cloneWithAssumptions( + version, assumptions, [&](ClosureVersion* newCls) { + updateVersionWithNewAssumptions(newCls); + }); + + if (newVersion != version) { + newVersion->isClone = true; + + call->lastSeen = newVersion; + version->staticCallRefCount--; + } + + } else { + + newVersion = target->replaceWithAssumptions( + version, assumptions, updateVersionWithNewAssumptions); + + call->lastSeen = newVersion; + if (newVersion->staticCallRefCount == 0) { + newVersion->staticCallRefCount = 1; + } + } + + return newVersion; +} + bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, AbstractLog& log, size_t) const { AvailableCheckpoints checkpoint(cls, code, log); @@ -288,7 +328,6 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, // version explosion. if (availableAssumptions.isImproving(version)) { - ClosureVersion* newVersion; auto updateVersionWithNewAssumptions = [&](ClosureVersion* newCls) { @@ -312,37 +351,9 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, }); }; - if (version != call->lastSeen) { - version->staticCallRefCount++; - } - - call->lastSeen = nullptr; - if (version->isClone || version->staticCallRefCount > 1) { - - newVersion = target->cloneWithAssumptions( - version, availableAssumptions, - [&](ClosureVersion* newCls) { - updateVersionWithNewAssumptions(newCls); - }); - - if (newVersion != version) { - newVersion->isClone = true; - - call->lastSeen = newVersion; - version->staticCallRefCount--; - } - - } else { - - newVersion = target->replaceWithAssumptions( - version, availableAssumptions, - updateVersionWithNewAssumptions); - - call->lastSeen = newVersion; - if (newVersion->staticCallRefCount == 0) { - newVersion->staticCallRefCount = 1; - } - } + auto newVersion = cloneOrReplaceVersion( + target, version, call, availableAssumptions, + updateVersionWithNewAssumptions); call->hint = newVersion; assert(call->tryDispatch() == newVersion); diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index d24511bff..ec602f18a 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -2283,6 +2283,7 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { size_t nCallArgs() const override { return nargs() - 3; } Instruction* clone() const override; + void updateVersionRefCount(); void eachNamedCallArg(const NamedArgumentValueIterator& it) const override { for (size_t i = 0; i < nCallArgs(); ++i) @@ -2335,9 +2336,6 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { assert((res & minimal) == minimal); return res; } - - private: - void updateVersionRefCount(); }; typedef SEXP (*CCODE)(SEXP, SEXP, SEXP, SEXP);