Skip to content

Commit

Permalink
Factored out match lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
LPeter1997 committed Oct 5, 2023
1 parent a70217a commit 6d2ca42
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 84 deletions.
85 changes: 1 addition & 84 deletions src/Draco.Compiler/Internal/Lowering/LocalRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Draco.Compiler.Internal.Lowering;
/// <summary>
/// Performs local rewrites of the source code.
/// </summary>
internal partial class LocalRewriter : BoundTreeRewriter
internal sealed partial class LocalRewriter : BoundTreeRewriter
{
/// <summary>
/// Represents a value that was temporarily stored.
Expand Down Expand Up @@ -201,61 +201,6 @@ public override BoundNode VisitIfExpression(BoundIfExpression node)
value: LocalExpression(result));
}

public override BoundNode VisitMatchExpression(BoundMatchExpression node)
{
// match (matchedExpr) {
// pattern1 if (guard1) -> value1;
// pattern2 if (guard2) -> value2;
// ...
// }
//
// =>
//
// {
// val tmp = matchedExpr;
// if (matches-pattern1(tmp) and guard1) value1
// else if (matches-pattern2(tmp) and guard2) value2
// ...
// }

// Evaluate the matched value as a local to not duplicate side-effects
var matchedValue = (BoundExpression)node.MatchedValue.Accept(this);
var tmp = this.StoreTemporary(matchedValue);

var conditionValuePairs = new List<(BoundExpression Condition, BoundExpression Value)>();
foreach (var arm in node.MatchArms)
{
var patternMatcher = this.ConstructPatternToMatch(arm.Pattern, tmp.Reference);
var guard = arm.Guard ?? this.LiteralExpression(true);

var condition = AndExpression(patternMatcher, guard);

conditionValuePairs.Add((condition, arm.Value));
}

// NOTE: We do an r-fold to nest if-expressions to the right
conditionValuePairs.Reverse();
// TODO: The match-expr might be empty!
var lastPair = conditionValuePairs[0];
var ifChain = (BoundExpression)conditionValuePairs
.Skip(1)
.Aggregate(
IfExpression(
lastPair.Condition,
lastPair.Value,
BoundTreeFactory.LiteralExpression(null, node.TypeRequired),
node.TypeRequired),
(acc, arm) => IfExpression(arm.Condition, arm.Value, acc, node.TypeRequired))
.Accept(this);

return BlockExpression(
locals: tmp.Symbol is null
? ImmutableArray<LocalSymbol>.Empty
: ImmutableArray.Create(tmp.Symbol),
statements: ImmutableArray.Create(tmp.Assignment),
ifChain);
}

public override BoundNode VisitWhileExpression(BoundWhileExpression node)
{
// while (condition)
Expand Down Expand Up @@ -634,34 +579,6 @@ public override BoundNode VisitIndexGetExpression(BoundIndexGetExpression node)
arguments: args);
}

private BoundExpression ConstructPatternToMatch(BoundPattern pattern, BoundExpression matchedValue) => pattern switch
{
BoundDiscardPattern discard => this.ConstructPatternToMatch(discard, matchedValue),
BoundLiteralPattern literal => this.ConstructPatternToMatch(literal, matchedValue),
_ => throw new ArgumentOutOfRangeException(nameof(pattern)),
};

private BoundExpression ConstructPatternToMatch(BoundDiscardPattern pattern, BoundExpression matchedValue) =>
// _ matches expr
//
// =>
//
// true
this.LiteralExpression(true);

private BoundExpression ConstructPatternToMatch(BoundLiteralPattern pattern, BoundExpression matchedValue) =>
// N matches expr
//
// =>
//
// object.Equals(N, expr)
CallExpression(
receiver: null,
method: this.WellKnownTypes.SystemObject_Equals,
arguments: ImmutableArray.Create(
this.LiteralExpression(pattern.Value),
matchedValue));

// Utility to store an expression to a temporary variable
private TemporaryStorage StoreTemporary(BoundExpression expr)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Draco.Compiler.Internal.BoundTree;
using Draco.Compiler.Internal.Symbols;
using static Draco.Compiler.Internal.BoundTree.BoundTreeFactory;

namespace Draco.Compiler.Internal.Lowering;

internal sealed partial class LocalRewriter
{
public override BoundNode VisitMatchExpression(BoundMatchExpression node)
{
// match (matchedExpr) {
// pattern1 if (guard1) -> value1;
// pattern2 if (guard2) -> value2;
// ...
// }
//
// =>
//
// {
// val tmp = matchedExpr;
// if (matches-pattern1(tmp) and guard1) value1
// else if (matches-pattern2(tmp) and guard2) value2
// ...
// }

// Evaluate the matched value as a local to not duplicate side-effects
var matchedValue = (BoundExpression)node.MatchedValue.Accept(this);
var tmp = this.StoreTemporary(matchedValue);

var conditionValuePairs = new List<(BoundExpression Condition, BoundExpression Value)>();
foreach (var arm in node.MatchArms)
{
var patternMatcher = this.ConstructPatternToMatch(arm.Pattern, tmp.Reference);
var guard = arm.Guard ?? this.LiteralExpression(true);

var condition = AndExpression(patternMatcher, guard);

conditionValuePairs.Add((condition, arm.Value));
}

// NOTE: We do an r-fold to nest if-expressions to the right
conditionValuePairs.Reverse();
// TODO: The match-expr might be empty!
var lastPair = conditionValuePairs[0];
var ifChain = (BoundExpression)conditionValuePairs
.Skip(1)
.Aggregate(
IfExpression(
lastPair.Condition,
lastPair.Value,
BoundTreeFactory.LiteralExpression(null, node.TypeRequired),
node.TypeRequired),
(acc, arm) => IfExpression(arm.Condition, arm.Value, acc, node.TypeRequired))
.Accept(this);

return BlockExpression(
locals: tmp.Symbol is null
? ImmutableArray<LocalSymbol>.Empty
: ImmutableArray.Create(tmp.Symbol),
statements: ImmutableArray.Create(tmp.Assignment),
ifChain);
}

private BoundExpression ConstructPatternToMatch(BoundPattern pattern, BoundExpression matchedValue) => pattern switch
{
BoundDiscardPattern discard => this.ConstructPatternToMatch(discard, matchedValue),
BoundLiteralPattern literal => this.ConstructPatternToMatch(literal, matchedValue),
_ => throw new ArgumentOutOfRangeException(nameof(pattern)),
};

private BoundExpression ConstructPatternToMatch(BoundDiscardPattern pattern, BoundExpression matchedValue) =>
// _ matches expr
//
// =>
//
// true
this.LiteralExpression(true);

private BoundExpression ConstructPatternToMatch(BoundLiteralPattern pattern, BoundExpression matchedValue) =>
// N matches expr
//
// =>
//
// object.Equals(N, expr)
CallExpression(
receiver: null,
method: this.WellKnownTypes.SystemObject_Equals,
arguments: ImmutableArray.Create(
this.LiteralExpression(pattern.Value),
matchedValue));
}

0 comments on commit 6d2ca42

Please sign in to comment.