Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce version cloning #1208

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@
"cfenv": "cpp",
"csignal": "cpp",
"__functional_base_03": "cpp",
"__memory": "cpp"
"__memory": "cpp",
"__bits": "cpp",
"__availability": "cpp"
}
}
41 changes: 41 additions & 0 deletions rir/src/compiler/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,43 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction,

bool MEASURE_COMPILER_PERF = getenv("PIR_MEASURE_COMPILER") ? true : false;

static void resetVersionsRefCountState(Module* m) {

m->eachPirClosure([&](Closure* c) {
c->eachVersion([&](ClosureVersion* v) {
v->staticCallRefCount = 0;
v->isClone = false;
});
});

m->eachPirClosure([&](Closure* c) {
c->eachVersion([&](ClosureVersion* v) {
auto check = [&](Instruction* i) {
if (auto call = StaticCall::Cast(i)) {
call->updateVersionRefCount();
} else if (auto call = CallInstruction::CastCall(i)) {
if (auto cls = call->tryGetCls()) {

if (auto dispatchedVersion = call->tryDispatch(cls)) {
dispatchedVersion->staticCallRefCount++;
}
}
}
};

Visitor::run(v->entry, check);
v->eachPromise([&](Promise* p) { Visitor::run(p->entry, check); });
});
});

m->eachPirClosure([&](Closure* c) {
c->eachVersion([&](ClosureVersion* v) {
if (v->staticCallRefCount == 0)
v->staticCallRefCount = 1;
});
});
}

static void findUnreachable(Module* m, Log& log, const std::string& where) {
std::unordered_map<Closure*, std::unordered_set<Context>> reachable;
bool changed = true;
Expand Down Expand Up @@ -283,6 +320,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) {
i->printRecursive(msg, 2);
log.warn(msg.str());
}

found(call->tryDispatch());
found(call->tryOptimisticDispatch());
found(call->hint);
Expand Down Expand Up @@ -310,6 +348,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) {
m->eachPirClosure([&](Closure* c) {
const auto& reachableVersions = reachable[c];
c->eachVersion([&](ClosureVersion* v) {
// assert(c->getVersion(v->context()) == v);
if (!reachableVersions.count(v->context())) {
toErase.push_back({v->owner(), v->context()});
log.close(v);
Expand All @@ -320,6 +359,8 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) {

for (auto e : toErase)
e.first->erase(e.second);

resetVersionsRefCountState(m);
};

void Compiler::optimizeClosureVersion(ClosureVersion* v) {
Expand Down
3 changes: 3 additions & 0 deletions rir/src/compiler/native/lower_function_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3387,6 +3387,9 @@ void LowerFunctionLLVM::compile() {
if (calli->isReordered())
callId = pushArgReordering(calli->getArgOrderOrig());

if (!target)
assert(target && "target is null!");

if (!target->owner()->hasOriginClosure()) {
setVal(
i, withCallFrame(args, [&]() -> llvm::Value* {
Expand Down
52 changes: 49 additions & 3 deletions rir/src/compiler/opt/eager_calls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,46 @@
namespace rir {
namespace pir {

static ClosureVersion*
cloneOrReplaceVersion(Closure* target, ClosureVersion* version,
StaticCall* call, Context assumptions,
const std::function<void(ClosureVersion*)>&
updateVersionWithNewAssumptions) {
if (version != call->lastSeen) {
version->staticCallRefCount++;
}

ClosureVersion* newVersion;

call->lastSeen = nullptr;
if (version->isClone || version->staticCallRefCount > 1) {

newVersion = target->cloneWithAssumptions(
version, assumptions, [&](ClosureVersion* newCls) {
updateVersionWithNewAssumptions(newCls);
});

if (newVersion != version) {
newVersion->isClone = true;

call->lastSeen = newVersion;
version->staticCallRefCount--;
}

} else {

newVersion = target->replaceWithAssumptions(
version, assumptions, updateVersionWithNewAssumptions);

call->lastSeen = newVersion;
if (newVersion->staticCallRefCount == 0) {
newVersion->staticCallRefCount = 1;
}
}

return newVersion;
}

bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code,
AbstractLog& log, size_t) const {
AvailableCheckpoints checkpoint(cls, code, log);
Expand Down Expand Up @@ -287,8 +327,9 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code,
// version. Maybe we should limit this at some point, to avoid
// version explosion.
if (availableAssumptions.isImproving(version)) {
auto newVersion = target->cloneWithAssumptions(
version, availableAssumptions,


auto updateVersionWithNewAssumptions =
[&](ClosureVersion* newCls) {
Visitor::run(newCls->entry, [&](Instruction* i) {
if (auto f = Force::Cast(i)) {
Expand All @@ -308,7 +349,12 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code,
}
}
});
});
};

auto newVersion = cloneOrReplaceVersion(
target, version, call, availableAssumptions,
updateVersionWithNewAssumptions);

call->hint = newVersion;
assert(call->tryDispatch() == newVersion);
ip = next;
Expand Down
3 changes: 3 additions & 0 deletions rir/src/compiler/opt/match_call_args.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,19 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code,

if (staticallyArgmatched && target) {
anyChange = true;

if (auto c = call) {
assert(!usemethodTarget);
auto cls = c->cls()->followCastsAndForce();

auto nc = new StaticCall(
c->env(), target, asmpt, matchedArgs, argOrderOrig,
c->frameStateOrTs(), c->srcIdx, cls);
(*ip)->replaceUsesAndSwapWith(nc, ip);
} else if (auto c = namedCall) {
assert(!usemethodTarget);
auto cls = c->cls()->followCastsAndForce();

auto nc = new StaticCall(
c->env(), target, asmpt, matchedArgs, argOrderOrig,
c->frameStateOrTs(), c->srcIdx, cls);
Expand Down
19 changes: 19 additions & 0 deletions rir/src/compiler/pir/closure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ ClosureVersion* Closure::cloneWithAssumptions(ClosureVersion* version,
return copy;
}

ClosureVersion* Closure::replaceWithAssumptions(ClosureVersion* version,
const Context& asmpt,
const MaybeClsVersion& change) {
assert(versions.count(version->context()) > 0);

auto newCtx = version->context() | asmpt;
if (versions.count(newCtx)) {
return versions.at(newCtx);
}

versions.erase(version->context());

version->setContext(newCtx);

versions[newCtx] = version;
change(version);
return version;
}

ClosureVersion* Closure::findCompatibleVersion(const Context& ctx) const {
// ordered by number of assumptions
for (auto& candidate : versions) {
Expand Down
5 changes: 5 additions & 0 deletions rir/src/compiler/pir/closure.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Closure {
Context userContext_;

public:

bool matchesUserContext(Context c) const {
return c.smaller(this->userContext_);
}
Expand Down Expand Up @@ -90,6 +91,10 @@ class Closure {
const Context& asmpt,
const MaybeClsVersion& change);

ClosureVersion* replaceWithAssumptions(ClosureVersion* cls,
const Context& asmpt,
const MaybeClsVersion& change);

typedef std::function<void(pir::ClosureVersion*)> ClosureVersionIterator;
void eachVersion(ClosureVersionIterator it) const;

Expand Down
8 changes: 7 additions & 1 deletion rir/src/compiler/pir/closure_version.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ class ClosureVersion : public Code {
const bool root;

rir::Function* optFunction;
size_t staticCallRefCount = 0;

private:
Closure* owner_;
std::vector<Promise*> promises_;
const Context optimizationContext_;
Context optimizationContext_;

std::string name_;
std::string nameSuffix_;
Expand All @@ -59,9 +60,14 @@ class ClosureVersion : public Code {
friend class Closure;

public:
bool isClone = false;

ClosureVersion* clone(const Context& newContext);

const Context& context() const { return optimizationContext_; }
void setContext(const Context& newContext) {
optimizationContext_ = newContext;
}

Properties properties;

Expand Down
21 changes: 21 additions & 0 deletions rir/src/compiler/pir/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,7 @@ StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion,
Context givenContext, const std::vector<Value*>& args,
const ArglistOrder::CallArglistOrder& argOrderOrig,
Value* fs, unsigned srcIdx, Value* runtimeClosure)

: VarLenInstructionWithEnvSlot(PirType::val(), callerEnv, srcIdx),
cls_(clsVersion->owner()), argOrderOrig(argOrderOrig),
givenContext(givenContext) {
Expand All @@ -1276,7 +1277,27 @@ StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion,
}
}

auto dispatched = tryDispatch();
assert(tryDispatch() == clsVersion);

// Update version-refCount fields
dispatched->staticCallRefCount++;
lastSeen = dispatched;
}

void StaticCall::updateVersionRefCount() {
lastSeen = nullptr;
auto target = tryDispatch();
if (target) {
lastSeen = target;
target->staticCallRefCount++;
}
}

Instruction* StaticCall::clone() const {
auto sc = StaticCall::Cast(InstructionImplementation::clone());
sc->updateVersionRefCount();
return sc;
}

PirType StaticCall::inferType(const GetType& getType) const {
Expand Down
5 changes: 5 additions & 0 deletions rir/src/compiler/pir/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace pir {

class BB;
class Closure;

class Phi;

struct InstrArg {
Expand Down Expand Up @@ -2268,6 +2269,7 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction {
const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs,
unsigned srcIdx, Value* runtimeClosure = Tombstone::closure());

ClosureVersion* lastSeen = nullptr;
Context givenContext;

ClosureVersion* hint = nullptr;
Expand All @@ -2279,6 +2281,9 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction {

size_t nCallArgs() const override { return nargs() - 3; }

Instruction* clone() const override;
void updateVersionRefCount();

void eachNamedCallArg(const NamedArgumentValueIterator& it) const override {
for (size_t i = 0; i < nCallArgs(); ++i)
it(R_NilValue, arg(i + 2).val());
Expand Down
3 changes: 2 additions & 1 deletion rir/src/compiler/rir2pir/rir2pir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,13 +928,14 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos,
assert(!inlining());
auto fs = insert.registerFrameState(srcCode, nextPos, stack,
inPromise());

auto cl = insert(
new StaticCall(insert.env, f, given, matchedArgs,
std::move(argOrderOrig), fs, ast,
f->owner()->closureEnv() == Env::notClosed()
? guardedCallee
: Tombstone::closure()));
cl->effects.set(Effect::DependsOnAssume);

push(cl);

auto innerc = MkCls::Cast(guardedCallee->followCastsAndForce());
Expand Down