Skip to content

Commit

Permalink
Refactor overloading (#413)
Browse files Browse the repository at this point in the history
* Factored out call scoring utilities

* Create CallUtilities.cs

* Locator simplification

* Removed crap that got factored out

* Fixed compilation

* Update Rename.cs

* Lots of shuffling for overload sets

* Merged call logics

* Code cleanup

* Update Draco.ProjectSystem.csproj

* Update Draco.SourceGeneration.csproj
  • Loading branch information
LPeter1997 authored Jul 16, 2024
1 parent 1e32489 commit 428071a
Show file tree
Hide file tree
Showing 29 changed files with 538 additions and 358 deletions.
1 change: 0 additions & 1 deletion src/Draco.Chr.Tests/DeduplicateTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Collections.Immutable;
using Draco.Chr.Constraints;
using Draco.Chr.Rules;
using Draco.Chr.Solve;
Expand Down
1 change: 0 additions & 1 deletion src/Draco.Chr.Tests/FibonacciTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Collections.Immutable;
using Draco.Chr.Constraints;
using Draco.Chr.Rules;
using Draco.Chr.Solve;
Expand Down
1 change: 0 additions & 1 deletion src/Draco.Chr/Rules/RuleFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
Expand Down
1 change: 0 additions & 1 deletion src/Draco.Chr/Solve/DefinitionOrderSolver.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Draco.Chr.Rules;
using Draco.Chr.Tracing;
Expand Down
1 change: 0 additions & 1 deletion src/Draco.Chr/Tracing/BroadcastTracer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Draco.Chr.Constraints;
Expand Down
4 changes: 1 addition & 3 deletions src/Draco.Chr/Tracing/ITracer.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using System.Collections.Generic;
using Draco.Chr.Constraints;
using System.Collections.Immutable;
using Draco.Chr.Rules;
using System.Collections;
using System.Collections.Generic;

namespace Draco.Chr.Tracing;

Expand Down
1 change: 0 additions & 1 deletion src/Draco.Chr/Tracing/NullTracer.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using Draco.Chr.Constraints;
using Draco.Chr.Rules;

Expand Down
1 change: 0 additions & 1 deletion src/Draco.Chr/Tracing/StreamTracer.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using Draco.Chr.Constraints;
using Draco.Chr.Rules;
Expand Down
3 changes: 2 additions & 1 deletion src/Draco.Compiler/Internal/Binding/Binder_Lvalue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Draco.Compiler.Internal.BoundTree;
using Draco.Compiler.Internal.Diagnostics;
using Draco.Compiler.Internal.Solver;
using Draco.Compiler.Internal.Solver.OverloadResolution;
using Draco.Compiler.Internal.Solver.Tasks;
using Draco.Compiler.Internal.Symbols;
using Draco.Compiler.Internal.Symbols.Error;
Expand Down Expand Up @@ -179,7 +180,7 @@ private async BindingTask<BoundLvalue> BindIndexLvalue(IndexExpressionSyntax syn
argsTask
.Zip(syntax.IndexList.Values)
.Select(pair => constraints.Arg(pair.Second, pair.First, diagnostics))
.Append(new ConstraintSolver.Argument(null, returnType))
.Append(new Argument(null, returnType))
.ToImmutableArray(),
// NOTE: We don't care about the return type, this is an lvalue
out _,
Expand Down
5 changes: 3 additions & 2 deletions src/Draco.Compiler/Internal/Solver/CallConstraint.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using Draco.Compiler.Internal.Solver.OverloadResolution;
using Draco.Compiler.Internal.Symbols;
using Draco.Compiler.Internal.Utilities;

Expand All @@ -9,7 +10,7 @@ namespace Draco.Compiler.Internal.Solver;
/// </summary>
internal sealed class CallConstraint(
TypeSymbol calledType,
ImmutableArray<ConstraintSolver.Argument> arguments,
ImmutableArray<Argument> arguments,
TypeSymbol returnType,
ConstraintLocator locator) : Constraint<Unit>(locator)
{
Expand All @@ -21,7 +22,7 @@ internal sealed class CallConstraint(
/// <summary>
/// The arguments the function was called with.
/// </summary>
public ImmutableArray<ConstraintSolver.Argument> Arguments { get; } = arguments;
public ImmutableArray<Argument> Arguments { get; } = arguments;

/// <summary>
/// The return type of the call.
Expand Down
4 changes: 3 additions & 1 deletion src/Draco.Compiler/Internal/Solver/ConstraintLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ internal abstract class ConstraintLocator
/// </summary>
/// <param name="syntax">The syntax node to connect the location to.</param>
/// <returns>The locator that will point at the syntax.</returns>
public static ConstraintLocator Syntax(SyntaxNode syntax) => new SyntaxConstraintLocator(syntax);
public static ConstraintLocator Syntax(SyntaxNode? syntax) => syntax is null
? Null
: new SyntaxConstraintLocator(syntax);

/// <summary>
/// Creates a constraint locator based on anonter constraint.
Expand Down
8 changes: 1 addition & 7 deletions src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Draco.Compiler.Internal.Binding.Tasks;
using Draco.Compiler.Internal.BoundTree;
using Draco.Compiler.Internal.Diagnostics;
using Draco.Compiler.Internal.Solver.OverloadResolution;
using Draco.Compiler.Internal.Solver.Tasks;
using Draco.Compiler.Internal.Symbols;
using Draco.Compiler.Internal.Symbols.Error;
Expand All @@ -19,13 +20,6 @@ namespace Draco.Compiler.Internal.Solver;
/// </summary>
internal sealed partial class ConstraintSolver(SyntaxNode context, string contextName)
{
/// <summary>
/// Represents an argument for a call.
/// </summary>
/// <param name="Syntax">The syntax of the argument, if any.</param>
/// <param name="Type">The type of the argument.</param>
public readonly record struct Argument(SyntaxNode? Syntax, TypeSymbol Type);

/// <summary>
/// The context being inferred.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Draco.Compiler.Api.Syntax;
using Draco.Compiler.Internal.Solver.OverloadResolution;
using Draco.Compiler.Internal.Solver.Tasks;
using Draco.Compiler.Internal.Symbols;
using Draco.Compiler.Internal.Utilities;
Expand Down Expand Up @@ -192,7 +193,8 @@ public SolverTask<FunctionSymbol> Overload(
SyntaxNode syntax)
{
returnType = this.AllocateTypeVariable();
var constraint = new OverloadConstraint(name, functions, args, returnType, ConstraintLocator.Syntax(syntax));
var candidateSet = OverloadCandidateSet.Create(functions, args);
var constraint = new OverloadConstraint(name, candidateSet, returnType, ConstraintLocator.Syntax(syntax));
this.Add(constraint);
return constraint.CompletionSource.Task;
}
Expand Down
64 changes: 25 additions & 39 deletions src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq;
using Draco.Compiler.Internal.Binding;
using Draco.Compiler.Internal.Diagnostics;
using Draco.Compiler.Internal.Solver.OverloadResolution;
using Draco.Compiler.Internal.Symbols;
using Draco.Compiler.Internal.Symbols.Error;
using Draco.Compiler.Internal.Symbols.Synthetized;
Expand Down Expand Up @@ -238,31 +239,18 @@ private void HandleRule(MemberConstraint constraint, DiagnosticBag? diagnostics)
private void HandleRule(OverloadConstraint constraint, DiagnosticBag? diagnostics)
{
var functionName = constraint.Name;
var functionsWithMatchingArgc = constraint.Candidates
.Where(f => MatchesParameterCount(f, constraint.Arguments.Length))
.ToList();
var maxArgc = functionsWithMatchingArgc
.Select(f => f.Parameters.Length)
.Append(0)
.Max();
var candidates = functionsWithMatchingArgc
.Select(f => new OverloadCandidate(f, new(maxArgc)))
.ToList();

while (true)
{
var changed = this.RefineOverloadScores(candidates, constraint.Arguments, out var wellDefined);
if (wellDefined) break;
if (candidates.Count <= 1) break;
if (!changed) return;
}

var candidateSet = constraint.CandidateSet;
candidateSet.Refine();
// If it's not well-defined, we can't advance
if (!candidateSet.IsWellDefined) return;

// We have all candidates well-defined, find the absolute dominator
if (candidates.Count == 0)
if (candidateSet.Count == 0)
{
UnifyAsserted(constraint.ReturnType, WellKnownTypes.ErrorType);
// Best-effort shape approximation
var errorSymbol = new NoOverloadFunctionSymbol(constraint.Arguments.Length);
var errorSymbol = new NoOverloadFunctionSymbol(candidateSet.Arguments.Length);
constraint.ReportDiagnostic(diagnostics, diag => diag
.WithTemplate(TypeCheckingErrors.NoMatchingOverload)
.WithFormatArgs(functionName));
Expand All @@ -271,11 +259,13 @@ private void HandleRule(OverloadConstraint constraint, DiagnosticBag? diagnostic
}

// We have one or more, find the max dominator
var dominatingCandidates = GetDominatingCandidates(candidates);
var dominatingCandidates = CallScore
.FindDominatorsBy(candidateSet, c => c.Score)
.ToImmutableArray();
if (dominatingCandidates.Length == 1)
{
// Resolved fine, choose the symbol, which might generic-instantiate it
var chosen = this.ChooseSymbol(dominatingCandidates[0]);
var chosen = this.ChooseSymbol(dominatingCandidates[0].Data);

// Inference
if (chosen.IsVariadic)
Expand Down Expand Up @@ -367,24 +357,20 @@ private void HandleRule(CallConstraint constraint, DiagnosticBag? diagnostics)
}

// Start scoring args
var score = new CallScore(functionType.Parameters.Length);
while (true)
var candidate = CallCandidate.Create(functionType);
candidate.Refine(constraint.Arguments);

if (candidate.IsEliminated)
{
var changed = this.AdjustScore(functionType, constraint.Arguments, score);
if (score.HasZero)
{
// Error
UnifyAsserted(constraint.ReturnType, WellKnownTypes.ErrorType);
constraint.ReportDiagnostic(diagnostics, diag => diag
.WithTemplate(TypeCheckingErrors.TypeMismatch)
.WithFormatArgs(
functionType,
this.MakeMismatchedFunctionType(constraint.Arguments, functionType.ReturnType)));
constraint.CompletionSource.SetResult(default);
return;
}
if (score.IsWellDefined) break;
if (!changed) return;
// Error
UnifyAsserted(constraint.ReturnType, WellKnownTypes.ErrorType);
constraint.ReportDiagnostic(diagnostics, diag => diag
.WithTemplate(TypeCheckingErrors.TypeMismatch)
.WithFormatArgs(
functionType,
this.MakeMismatchedFunctionType(constraint.Arguments, functionType.ReturnType)));
constraint.CompletionSource.SetResult(default);
return;
}

// We are done
Expand Down
Loading

0 comments on commit 428071a

Please sign in to comment.