diff --git a/src/Draco.Compiler/Internal/Lowering/LocalRewriter_MatchExpression.cs b/src/Draco.Compiler/Internal/Lowering/LocalRewriter_MatchExpression.cs index 103fd42c4..b9d3b6284 100644 --- a/src/Draco.Compiler/Internal/Lowering/LocalRewriter_MatchExpression.cs +++ b/src/Draco.Compiler/Internal/Lowering/LocalRewriter_MatchExpression.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using System.Linq; using Draco.Compiler.Internal.BoundTree; using Draco.Compiler.Internal.FlowAnalysis; @@ -15,51 +16,87 @@ public override BoundNode VisitMatchExpression(BoundMatchExpression node) { // TODO: Elaborate on what we do here + // Evaluate the matched value as a local to not duplicate side-effects + var matchedValue = (BoundExpression)node.MatchedValue.Accept(this); + var tmp = this.StoreTemporary(matchedValue); + // We build up the relevant arms var arms = node.MatchArms .Select(a => DecisionTree.Arm(a.Pattern, a.Guard, a)) .ToImmutableArray(); // From that we build the decision tree - var decisionTree = DecisionTree.Build(this.IntrinsicSymbols, node.MatchedValue, arms); + var decisionTree = DecisionTree.Build(this.IntrinsicSymbols, tmp.Reference, arms); - // TODO: use it + // Recursively lower each decision node + var decisionNode = this.ConstructMatchNode(decisionTree.Root, node.TypeRequired); - // Evaluate the matched value as a local to not duplicate side-effects - var matchedValue = (BoundExpression)node.MatchedValue.Accept(this); - var tmp = this.StoreTemporary(matchedValue); + var result = BlockExpression( + locals: tmp.Symbol is null + ? ImmutableArray.Empty + : ImmutableArray.Create(tmp.Symbol), + statements: ImmutableArray.Create(tmp.Assignment), + decisionNode); + return result.Accept(this); + } - var conditionValuePairs = new List<(BoundExpression Condition, BoundExpression Value)>(); - foreach (var arm in node.MatchArms) - { - var patternMatcher = this.ConstructPatternToMatch(arm.Pattern, tmp.Reference); - var guard = arm.Guard ?? this.LiteralExpression(true); + private BoundExpression ConstructMatchNode(DecisionTree.INode node, TypeSymbol resultType) + { + // Hit a leaf + if (node.IsAction) return (BoundExpression)node.Action.Value.Accept(this); - var condition = AndExpression(patternMatcher, guard); + // TODO + // Failure + if (node.IsFail) throw new NotImplementedException(); - conditionValuePairs.Add((condition, arm.Value)); + // Fold backwards + var seed = null as BoundExpression; + for (var i = node.Children.Count - 1; i >= 0; --i) + { + var (cond, child) = node.Children[i]; + seed = Fold(seed, cond, child); } - // NOTE: We do an r-fold to nest if-expressions to the right - conditionValuePairs.Reverse(); - // TODO: The match-expr might be empty! - var lastPair = conditionValuePairs[0]; - var ifChain = (BoundExpression)conditionValuePairs - .Skip(1) - .Aggregate( - IfExpression( - lastPair.Condition, - lastPair.Value, - BoundTreeFactory.LiteralExpression(null, node.TypeRequired), - node.TypeRequired), - (acc, arm) => IfExpression(arm.Condition, arm.Value, acc, node.TypeRequired)) - .Accept(this); + Debug.Assert(seed is not null); + return seed; - return BlockExpression( - locals: tmp.Symbol is null - ? ImmutableArray.Empty - : ImmutableArray.Create(tmp.Symbol), - statements: ImmutableArray.Create(tmp.Assignment), - ifChain); + BoundExpression Fold( + BoundExpression? prev, + DecisionTree.Condition condition, + DecisionTree.INode node) + { + var matchedValue = (BoundExpression)node.MatchedValue; + var result = this.ConstructMatchNode(node, resultType); + switch (condition.Pattern, condition.Guard) + { + case (null, null): + { + Debug.Assert(prev is null); + return result; + } + case (BoundPattern pat, null): + { + var matchCondition = this.ConstructPatternToMatch(pat, matchedValue); + return IfExpression( + condition: matchCondition, + then: result, + @else: prev ?? DefaultExpression(), + type: resultType); + } + case (null, BoundExpression cond): + { + return IfExpression( + condition: cond, + then: result, + @else: prev ?? DefaultExpression(), + type: resultType); + } + default: + throw new InvalidOperationException(); + } + } + + // TODO: As a default, we should THROW + static BoundExpression DefaultExpression() => BoundUnitExpression.Default; } private BoundExpression ConstructPatternToMatch(BoundPattern pattern, BoundExpression matchedValue) => pattern switch