Skip to content

Commit

Permalink
Update DecisionTree.cs
Browse files Browse the repository at this point in the history
  • Loading branch information
LPeter1997 committed Oct 4, 2023
1 parent 3ab626c commit a70217a
Showing 1 changed file with 140 additions and 27 deletions.
167 changes: 140 additions & 27 deletions src/Draco.Compiler/Internal/FlowAnalysis/DecisionTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Text;
using System.Threading.Tasks;
using Draco.Compiler.Internal.BoundTree;
using Draco.Compiler.Internal.FlowAnalysis.Domain;
using Draco.Compiler.Internal.Utilities;

namespace Draco.Compiler.Internal.FlowAnalysis;

Expand All @@ -25,12 +27,12 @@ internal sealed class DecisionTree<TAction>
public readonly record struct Arm(BoundPattern Pattern, TAction Action);

/// <summary>
/// Represents a redundant arm.
/// Represents a redundancy.
/// </summary>
/// <param name="CoveredBy">The arm that covers the <paramref name="Redundant"/> one already.</param>
/// <param name="Redundant">The arm that is redundant, because <paramref name="CoveredBy"/> already
/// <param name="CoveredBy">The action that covers the <paramref name="Redundant"/> one already.</param>
/// <param name="Redundant">The action that is redundant, because <paramref name="CoveredBy"/> already
/// matches.</param>
public readonly record struct Redundance(Arm CoveredBy, Arm Redundant);
public readonly record struct Redundance(TAction CoveredBy, TAction Redundant);

/// <summary>
/// A single node in the decision tree.
Expand All @@ -48,10 +50,10 @@ public abstract class Node
public abstract bool IsAction { get; }

/// <summary>
/// The arm that's associated with the node.
/// The action that's associated with the node, in case it's a leaf.
/// </summary>
[MemberNotNullWhen(true, nameof(IsAction))]
public abstract Arm? Arm { get; }
public abstract TAction? Action { get; }

/// <summary>
/// True, if this is a failure node.
Expand Down Expand Up @@ -79,22 +81,71 @@ private sealed class MutableNode : Node
{
// Observers
public override Node? Parent { get; }
public override bool IsAction => this.Arm is not null;
public override Arm? Arm => this.MatchingArm;
public override bool IsAction => this.Action is not null;
public override TAction? Action => this.MutableAction;
public override bool IsFail => this.PatternMatrix.Count == 0;
public override BoundPattern? Counterexample => throw new NotImplementedException();
public override BoundExpression? MatchedOn => throw new NotImplementedException();
public override ImmutableArray<KeyValuePair<BoundPattern, Node>> Children =>
this.builtChildren ??= this.MutableChildren.ToImmutableArray();
this.builtChildren ??= this.MutableChildren
.Select(n => new KeyValuePair<BoundPattern, Node>(n.Key, n.Value))
.ToImmutableArray();
private ImmutableArray<KeyValuePair<BoundPattern, Node>>? builtChildren;

// Mutators
public List<BoundExpression> Arguments { get; }
public List<int> ArgumentOrder { get; }
public List<List<BoundPattern>> PatternMatrix { get; }
public List<Arm> Arms { get; }
public Arm? MatchingArm { get; set; }
public List<KeyValuePair<BoundPattern, Node>> MutableChildren { get; }
public List<TAction> Actions { get; }
public TAction? MutableAction { get; set; }
public List<KeyValuePair<BoundPattern, MutableNode>> MutableChildren { get; } = new();

public MutableNode(
Node? parent,
List<BoundExpression> arguments,
List<List<BoundPattern>> patternMatrix,
List<TAction> actions)
{
this.Parent = parent;
this.Arguments = arguments;
this.ArgumentOrder = Enumerable
.Range(0, this.Arguments.Count)
.ToList();
this.PatternMatrix = patternMatrix;
this.Actions = actions;
}

public void SwapColumns(int i, int j)
{
Swap(this.Arguments, i, j);
Swap(this.ArgumentOrder, i, j);
foreach (var row in this.PatternMatrix) Swap(row, i, j);
}

private static void Swap<T>(List<T> list, int i, int j) => (list[i], list[j]) = (list[j], list[i]);
}

/// <summary>
/// A comparer that compares patterns only in terms of specialization, not involving their
/// arguments, if there are any.
/// </summary>
private sealed class SpecializationComparer : IEqualityComparer<BoundPattern>
{
public static SpecializationComparer Instance { get; } = new();

private SpecializationComparer()
{
}

public bool Equals(BoundPattern? x, BoundPattern? y) => (x, y) switch
{
_ => throw new ArgumentOutOfRangeException(paramName: null, message: "unhandled pair of patterns"),
};

public int GetHashCode([DisallowNull] BoundPattern obj) => obj switch
{
_ => throw new ArgumentOutOfRangeException(nameof(obj)),
};
}

/// <summary>
Expand All @@ -116,38 +167,100 @@ public static DecisionTree<TAction> Build(BoundExpression matchedValue, Immutabl
_ => throw new ArgumentOutOfRangeException(nameof(pattern)),
};

/// <summary>
/// The matched value.
/// </summary>
public BoundExpression MatchedValue { get; }

/// <summary>
/// The arms of the root construct.
/// </summary>
public ImmutableArray<Arm> Arms { get; }

/// <summary>
/// The root node of this tree.
/// </summary>
public Node Root { get; }
public Node Root => this.mutableRoot;
private readonly MutableNode mutableRoot;

/// <summary>
/// All redundancies in the tree.
/// </summary>
public ImmutableArray<Redundance> Redundancies { get; }
public ImmutableArray<Redundance> Redundancies => this.redundancies.ToImmutable();
private readonly ImmutableArray<Redundance>.Builder redundancies = ImmutableArray.CreateBuilder<Redundance>();

/// <summary>
/// True, if this tree is exhaustive.
/// </summary>
public bool IsExhaustive { get; }
public bool IsExhaustive => GraphTraversal
.DepthFirst(this.Root, n => n.Children.Select(c => c.Value))
.All(n => !n.IsFail);

/// <summary>
/// An example of an uncovered pattern, if any.
/// </summary>
public BoundPattern? UncoveredExample { get; }
public BoundPattern? UncoveredExample => throw new NotImplementedException();

private DecisionTree(MutableNode root)
{
this.mutableRoot = root;
}

private DecisionTree()
private void Build(MutableNode node)
{
if (node.IsFail) return;

if (node.PatternMatrix[0].All(MatchesEverything))
{
// This is a succeeding node, set the performed action
node.MutableAction = node.Actions[0];
// The remaining ones are redundant
for (var i = 1; i < node.PatternMatrix.Count; ++i)
{
this.redundancies.Add(new(node.Actions[0], node.Actions[i]));
}
return;
}

// We need to make a decision, bring the column that has refutable entries to the beginning
var firstColWithRefutable = FirstColumnWithRefutableEntry(node);
if (firstColWithRefutable != 0) node.SwapColumns(0, firstColWithRefutable);

// The first column now contains something that is refutable
// Collect all pattern variants that we covered
var coveredPatterns = node.PatternMatrix
.Select(row => row[0])
.Where(p => !MatchesEverything(p))
.ToHashSet(SpecializationComparer.Instance);

// Track if there are any uncovered values in this domain
// TODO
var uncoveredDomain = ValueDomain.CreateDomain(null!, node.PatternMatrix[0][0].Type);

// Specialize for each of these cases
foreach (var pat in coveredPatterns)
{
// Specialize to the pattern
this.Specialize(node, pat);
// We covered the value, subtract
uncoveredDomain.SubtractPattern(pat);
}

// If not complete, do defaulting
if (!uncoveredDomain.IsEmpty) this.Default(node);

// Recurse to children
foreach (var (_, child) in node.MutableChildren) this.Build(child);
}

private MutableNode Specialize(MutableNode node, BoundPattern specializer) =>
throw new NotImplementedException();

private MutableNode Default(MutableNode node) =>
throw new NotImplementedException();

private static int FirstColumnWithRefutableEntry(MutableNode node)
{
for (var col = 0; col < node.PatternMatrix[0].Count; ++col)
{
if (node.PatternMatrix.Any(row => !MatchesEverything(row[col]))) return col;
}

throw new InvalidOperationException("should not happen");
}

private static bool MatchesEverything(BoundPattern pattern) => pattern switch
{
_ => throw new ArgumentOutOfRangeException(nameof(pattern)),
};
}

0 comments on commit a70217a

Please sign in to comment.