Skip to content

Commit

Permalink
Update LocalRewriter_MatchExpression.cs
Browse files Browse the repository at this point in the history
  • Loading branch information
LPeter1997 committed Oct 15, 2023
1 parent 4186b7d commit a4e6a01
Showing 1 changed file with 69 additions and 32 deletions.
101 changes: 69 additions & 32 deletions src/Draco.Compiler/Internal/Lowering/LocalRewriter_MatchExpression.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Draco.Compiler.Internal.BoundTree;
using Draco.Compiler.Internal.FlowAnalysis;
Expand All @@ -15,51 +16,87 @@ public override BoundNode VisitMatchExpression(BoundMatchExpression node)
{
// TODO: Elaborate on what we do here

// Evaluate the matched value as a local to not duplicate side-effects
var matchedValue = (BoundExpression)node.MatchedValue.Accept(this);
var tmp = this.StoreTemporary(matchedValue);

// We build up the relevant arms
var arms = node.MatchArms
.Select(a => DecisionTree.Arm(a.Pattern, a.Guard, a))
.ToImmutableArray();
// From that we build the decision tree
var decisionTree = DecisionTree.Build(this.IntrinsicSymbols, node.MatchedValue, arms);
var decisionTree = DecisionTree.Build(this.IntrinsicSymbols, tmp.Reference, arms);

// TODO: use it
// Recursively lower each decision node
var decisionNode = this.ConstructMatchNode(decisionTree.Root, node.TypeRequired);

// Evaluate the matched value as a local to not duplicate side-effects
var matchedValue = (BoundExpression)node.MatchedValue.Accept(this);
var tmp = this.StoreTemporary(matchedValue);
var result = BlockExpression(
locals: tmp.Symbol is null
? ImmutableArray<LocalSymbol>.Empty
: ImmutableArray.Create(tmp.Symbol),
statements: ImmutableArray.Create(tmp.Assignment),
decisionNode);
return result.Accept(this);
}

var conditionValuePairs = new List<(BoundExpression Condition, BoundExpression Value)>();
foreach (var arm in node.MatchArms)
{
var patternMatcher = this.ConstructPatternToMatch(arm.Pattern, tmp.Reference);
var guard = arm.Guard ?? this.LiteralExpression(true);
private BoundExpression ConstructMatchNode(DecisionTree<BoundMatchArm>.INode node, TypeSymbol resultType)
{
// Hit a leaf
if (node.IsAction) return (BoundExpression)node.Action.Value.Accept(this);

var condition = AndExpression(patternMatcher, guard);
// TODO
// Failure
if (node.IsFail) throw new NotImplementedException();

conditionValuePairs.Add((condition, arm.Value));
// Fold backwards
var seed = null as BoundExpression;
for (var i = node.Children.Count - 1; i >= 0; --i)
{
var (cond, child) = node.Children[i];
seed = Fold(seed, cond, child);
}

// NOTE: We do an r-fold to nest if-expressions to the right
conditionValuePairs.Reverse();
// TODO: The match-expr might be empty!
var lastPair = conditionValuePairs[0];
var ifChain = (BoundExpression)conditionValuePairs
.Skip(1)
.Aggregate(
IfExpression(
lastPair.Condition,
lastPair.Value,
BoundTreeFactory.LiteralExpression(null, node.TypeRequired),
node.TypeRequired),
(acc, arm) => IfExpression(arm.Condition, arm.Value, acc, node.TypeRequired))
.Accept(this);
Debug.Assert(seed is not null);
return seed;

return BlockExpression(
locals: tmp.Symbol is null
? ImmutableArray<LocalSymbol>.Empty
: ImmutableArray.Create(tmp.Symbol),
statements: ImmutableArray.Create(tmp.Assignment),
ifChain);
BoundExpression Fold(
BoundExpression? prev,
DecisionTree<BoundMatchArm>.Condition condition,
DecisionTree<BoundMatchArm>.INode node)
{
var matchedValue = (BoundExpression)node.MatchedValue;
var result = this.ConstructMatchNode(node, resultType);
switch (condition.Pattern, condition.Guard)
{
case (null, null):
{
Debug.Assert(prev is null);
return result;
}
case (BoundPattern pat, null):
{
var matchCondition = this.ConstructPatternToMatch(pat, matchedValue);
return IfExpression(
condition: matchCondition,
then: result,
@else: prev ?? DefaultExpression(),
type: resultType);
}
case (null, BoundExpression cond):
{
return IfExpression(
condition: cond,
then: result,
@else: prev ?? DefaultExpression(),
type: resultType);
}
default:
throw new InvalidOperationException();
}
}

// TODO: As a default, we should THROW
static BoundExpression DefaultExpression() => BoundUnitExpression.Default;
}

private BoundExpression ConstructPatternToMatch(BoundPattern pattern, BoundExpression matchedValue) => pattern switch
Expand Down

0 comments on commit a4e6a01

Please sign in to comment.