From 428071a27af396d4fec0d588672f3e6fb09bc42d Mon Sep 17 00:00:00 2001 From: LPeter1997 Date: Tue, 16 Jul 2024 21:25:20 +0200 Subject: [PATCH] Refactor overloading (#413) * Factored out call scoring utilities * Create CallUtilities.cs * Locator simplification * Removed crap that got factored out * Fixed compilation * Update Rename.cs * Lots of shuffling for overload sets * Merged call logics * Code cleanup * Update Draco.ProjectSystem.csproj * Update Draco.SourceGeneration.csproj --- src/Draco.Chr.Tests/DeduplicateTest.cs | 1 - src/Draco.Chr.Tests/FibonacciTest.cs | 1 - src/Draco.Chr/Rules/RuleFactory.cs | 1 - src/Draco.Chr/Solve/DefinitionOrderSolver.cs | 1 - src/Draco.Chr/Tracing/BroadcastTracer.cs | 1 - src/Draco.Chr/Tracing/ITracer.cs | 4 +- src/Draco.Chr/Tracing/NullTracer.cs | 1 - src/Draco.Chr/Tracing/StreamTracer.cs | 1 - .../Internal/Binding/Binder_Lvalue.cs | 3 +- .../Internal/Solver/CallConstraint.cs | 5 +- .../Internal/Solver/ConstraintLocator.cs | 4 +- .../Internal/Solver/ConstraintSolver.cs | 8 +- .../Solver/ConstraintSolver_Constraints.cs | 4 +- .../Internal/Solver/ConstraintSolver_Rules.cs | 64 ++--- .../Internal/Solver/ConstraintSolver_Utils.cs | 221 +----------------- .../Internal/Solver/OverloadConstraint.cs | 17 +- .../Solver/OverloadResolution/Argument.cs | 11 + .../OverloadResolution/ArgumentScore.cs | 116 +++++++++ .../OverloadResolution/CallCandidate.cs | 118 ++++++++++ .../{ => OverloadResolution}/CallScore.cs | 135 +++++------ .../OverloadResolution/CallScoreComparison.cs | 32 +++ .../OverloadResolution/CallUtilities.cs | 31 +++ .../OverloadCandidateSet.cs | 106 +++++++++ .../Capabilities/Rename.cs | 2 +- .../Capabilities/TextDocumentSync.cs | 1 - src/Draco.ProjectSystem/DesignTimeBuild.cs | 1 - .../Draco.ProjectSystem.csproj | 2 +- src/Draco.ProjectSystem/Project.cs | 2 - .../Draco.SourceGeneration.csproj | 2 +- 29 files changed, 538 insertions(+), 358 deletions(-) create mode 100644 src/Draco.Compiler/Internal/Solver/OverloadResolution/Argument.cs create mode 100644 src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs create mode 100644 src/Draco.Compiler/Internal/Solver/OverloadResolution/CallCandidate.cs rename src/Draco.Compiler/Internal/Solver/{ => OverloadResolution}/CallScore.cs (54%) create mode 100644 src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScoreComparison.cs create mode 100644 src/Draco.Compiler/Internal/Solver/OverloadResolution/CallUtilities.cs create mode 100644 src/Draco.Compiler/Internal/Solver/OverloadResolution/OverloadCandidateSet.cs diff --git a/src/Draco.Chr.Tests/DeduplicateTest.cs b/src/Draco.Chr.Tests/DeduplicateTest.cs index 072ab276b..369e27fe7 100644 --- a/src/Draco.Chr.Tests/DeduplicateTest.cs +++ b/src/Draco.Chr.Tests/DeduplicateTest.cs @@ -1,4 +1,3 @@ -using System.Collections.Immutable; using Draco.Chr.Constraints; using Draco.Chr.Rules; using Draco.Chr.Solve; diff --git a/src/Draco.Chr.Tests/FibonacciTest.cs b/src/Draco.Chr.Tests/FibonacciTest.cs index 018641d54..2e64484ad 100644 --- a/src/Draco.Chr.Tests/FibonacciTest.cs +++ b/src/Draco.Chr.Tests/FibonacciTest.cs @@ -1,4 +1,3 @@ -using System.Collections.Immutable; using Draco.Chr.Constraints; using Draco.Chr.Rules; using Draco.Chr.Solve; diff --git a/src/Draco.Chr/Rules/RuleFactory.cs b/src/Draco.Chr/Rules/RuleFactory.cs index 69d58c8c6..55087b7f3 100644 --- a/src/Draco.Chr/Rules/RuleFactory.cs +++ b/src/Draco.Chr/Rules/RuleFactory.cs @@ -1,5 +1,4 @@ using System; -using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; diff --git a/src/Draco.Chr/Solve/DefinitionOrderSolver.cs b/src/Draco.Chr/Solve/DefinitionOrderSolver.cs index 65d7703be..367cd7b59 100644 --- a/src/Draco.Chr/Solve/DefinitionOrderSolver.cs +++ b/src/Draco.Chr/Solve/DefinitionOrderSolver.cs @@ -1,6 +1,5 @@ using System.Collections; using System.Collections.Generic; -using System.Collections.Immutable; using System.Linq; using Draco.Chr.Rules; using Draco.Chr.Tracing; diff --git a/src/Draco.Chr/Tracing/BroadcastTracer.cs b/src/Draco.Chr/Tracing/BroadcastTracer.cs index 7a06fef5e..ec3a8a8b9 100644 --- a/src/Draco.Chr/Tracing/BroadcastTracer.cs +++ b/src/Draco.Chr/Tracing/BroadcastTracer.cs @@ -1,4 +1,3 @@ -using System; using System.Collections.Generic; using System.Linq; using Draco.Chr.Constraints; diff --git a/src/Draco.Chr/Tracing/ITracer.cs b/src/Draco.Chr/Tracing/ITracer.cs index fa9a2b36d..bdd212f57 100644 --- a/src/Draco.Chr/Tracing/ITracer.cs +++ b/src/Draco.Chr/Tracing/ITracer.cs @@ -1,8 +1,6 @@ +using System.Collections.Generic; using Draco.Chr.Constraints; -using System.Collections.Immutable; using Draco.Chr.Rules; -using System.Collections; -using System.Collections.Generic; namespace Draco.Chr.Tracing; diff --git a/src/Draco.Chr/Tracing/NullTracer.cs b/src/Draco.Chr/Tracing/NullTracer.cs index 96a924c21..d52fa3ddc 100644 --- a/src/Draco.Chr/Tracing/NullTracer.cs +++ b/src/Draco.Chr/Tracing/NullTracer.cs @@ -1,5 +1,4 @@ using System.Collections.Generic; -using System.Collections.Immutable; using Draco.Chr.Constraints; using Draco.Chr.Rules; diff --git a/src/Draco.Chr/Tracing/StreamTracer.cs b/src/Draco.Chr/Tracing/StreamTracer.cs index c8a520dd9..dc20f8c72 100644 --- a/src/Draco.Chr/Tracing/StreamTracer.cs +++ b/src/Draco.Chr/Tracing/StreamTracer.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Collections.Immutable; using System.IO; using Draco.Chr.Constraints; using Draco.Chr.Rules; diff --git a/src/Draco.Compiler/Internal/Binding/Binder_Lvalue.cs b/src/Draco.Compiler/Internal/Binding/Binder_Lvalue.cs index bd2437ef8..7a1790982 100644 --- a/src/Draco.Compiler/Internal/Binding/Binder_Lvalue.cs +++ b/src/Draco.Compiler/Internal/Binding/Binder_Lvalue.cs @@ -7,6 +7,7 @@ using Draco.Compiler.Internal.BoundTree; using Draco.Compiler.Internal.Diagnostics; using Draco.Compiler.Internal.Solver; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Solver.Tasks; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Symbols.Error; @@ -179,7 +180,7 @@ private async BindingTask BindIndexLvalue(IndexExpressionSyntax syn argsTask .Zip(syntax.IndexList.Values) .Select(pair => constraints.Arg(pair.Second, pair.First, diagnostics)) - .Append(new ConstraintSolver.Argument(null, returnType)) + .Append(new Argument(null, returnType)) .ToImmutableArray(), // NOTE: We don't care about the return type, this is an lvalue out _, diff --git a/src/Draco.Compiler/Internal/Solver/CallConstraint.cs b/src/Draco.Compiler/Internal/Solver/CallConstraint.cs index 5ada8880c..152ce4050 100644 --- a/src/Draco.Compiler/Internal/Solver/CallConstraint.cs +++ b/src/Draco.Compiler/Internal/Solver/CallConstraint.cs @@ -1,4 +1,5 @@ using System.Collections.Immutable; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Utilities; @@ -9,7 +10,7 @@ namespace Draco.Compiler.Internal.Solver; /// internal sealed class CallConstraint( TypeSymbol calledType, - ImmutableArray arguments, + ImmutableArray arguments, TypeSymbol returnType, ConstraintLocator locator) : Constraint(locator) { @@ -21,7 +22,7 @@ internal sealed class CallConstraint( /// /// The arguments the function was called with. /// - public ImmutableArray Arguments { get; } = arguments; + public ImmutableArray Arguments { get; } = arguments; /// /// The return type of the call. diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintLocator.cs b/src/Draco.Compiler/Internal/Solver/ConstraintLocator.cs index c1dd9b083..bc07ab447 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintLocator.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintLocator.cs @@ -18,7 +18,9 @@ internal abstract class ConstraintLocator /// /// The syntax node to connect the location to. /// The locator that will point at the syntax. - public static ConstraintLocator Syntax(SyntaxNode syntax) => new SyntaxConstraintLocator(syntax); + public static ConstraintLocator Syntax(SyntaxNode? syntax) => syntax is null + ? Null + : new SyntaxConstraintLocator(syntax); /// /// Creates a constraint locator based on anonter constraint. diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs index dfac84016..407eafc4f 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs @@ -7,6 +7,7 @@ using Draco.Compiler.Internal.Binding.Tasks; using Draco.Compiler.Internal.BoundTree; using Draco.Compiler.Internal.Diagnostics; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Solver.Tasks; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Symbols.Error; @@ -19,13 +20,6 @@ namespace Draco.Compiler.Internal.Solver; /// internal sealed partial class ConstraintSolver(SyntaxNode context, string contextName) { - /// - /// Represents an argument for a call. - /// - /// The syntax of the argument, if any. - /// The type of the argument. - public readonly record struct Argument(SyntaxNode? Syntax, TypeSymbol Type); - /// /// The context being inferred. /// diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs index e9fad5105..99b644fb7 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using Draco.Compiler.Api.Syntax; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Solver.Tasks; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Utilities; @@ -192,7 +193,8 @@ public SolverTask Overload( SyntaxNode syntax) { returnType = this.AllocateTypeVariable(); - var constraint = new OverloadConstraint(name, functions, args, returnType, ConstraintLocator.Syntax(syntax)); + var candidateSet = OverloadCandidateSet.Create(functions, args); + var constraint = new OverloadConstraint(name, candidateSet, returnType, ConstraintLocator.Syntax(syntax)); this.Add(constraint); return constraint.CompletionSource.Task; } diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs index 656ac37b5..154856aca 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs @@ -4,6 +4,7 @@ using System.Linq; using Draco.Compiler.Internal.Binding; using Draco.Compiler.Internal.Diagnostics; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Symbols.Error; using Draco.Compiler.Internal.Symbols.Synthetized; @@ -238,31 +239,18 @@ private void HandleRule(MemberConstraint constraint, DiagnosticBag? diagnostics) private void HandleRule(OverloadConstraint constraint, DiagnosticBag? diagnostics) { var functionName = constraint.Name; - var functionsWithMatchingArgc = constraint.Candidates - .Where(f => MatchesParameterCount(f, constraint.Arguments.Length)) - .ToList(); - var maxArgc = functionsWithMatchingArgc - .Select(f => f.Parameters.Length) - .Append(0) - .Max(); - var candidates = functionsWithMatchingArgc - .Select(f => new OverloadCandidate(f, new(maxArgc))) - .ToList(); - - while (true) - { - var changed = this.RefineOverloadScores(candidates, constraint.Arguments, out var wellDefined); - if (wellDefined) break; - if (candidates.Count <= 1) break; - if (!changed) return; - } + + var candidateSet = constraint.CandidateSet; + candidateSet.Refine(); + // If it's not well-defined, we can't advance + if (!candidateSet.IsWellDefined) return; // We have all candidates well-defined, find the absolute dominator - if (candidates.Count == 0) + if (candidateSet.Count == 0) { UnifyAsserted(constraint.ReturnType, WellKnownTypes.ErrorType); // Best-effort shape approximation - var errorSymbol = new NoOverloadFunctionSymbol(constraint.Arguments.Length); + var errorSymbol = new NoOverloadFunctionSymbol(candidateSet.Arguments.Length); constraint.ReportDiagnostic(diagnostics, diag => diag .WithTemplate(TypeCheckingErrors.NoMatchingOverload) .WithFormatArgs(functionName)); @@ -271,11 +259,13 @@ private void HandleRule(OverloadConstraint constraint, DiagnosticBag? diagnostic } // We have one or more, find the max dominator - var dominatingCandidates = GetDominatingCandidates(candidates); + var dominatingCandidates = CallScore + .FindDominatorsBy(candidateSet, c => c.Score) + .ToImmutableArray(); if (dominatingCandidates.Length == 1) { // Resolved fine, choose the symbol, which might generic-instantiate it - var chosen = this.ChooseSymbol(dominatingCandidates[0]); + var chosen = this.ChooseSymbol(dominatingCandidates[0].Data); // Inference if (chosen.IsVariadic) @@ -367,24 +357,20 @@ private void HandleRule(CallConstraint constraint, DiagnosticBag? diagnostics) } // Start scoring args - var score = new CallScore(functionType.Parameters.Length); - while (true) + var candidate = CallCandidate.Create(functionType); + candidate.Refine(constraint.Arguments); + + if (candidate.IsEliminated) { - var changed = this.AdjustScore(functionType, constraint.Arguments, score); - if (score.HasZero) - { - // Error - UnifyAsserted(constraint.ReturnType, WellKnownTypes.ErrorType); - constraint.ReportDiagnostic(diagnostics, diag => diag - .WithTemplate(TypeCheckingErrors.TypeMismatch) - .WithFormatArgs( - functionType, - this.MakeMismatchedFunctionType(constraint.Arguments, functionType.ReturnType))); - constraint.CompletionSource.SetResult(default); - return; - } - if (score.IsWellDefined) break; - if (!changed) return; + // Error + UnifyAsserted(constraint.ReturnType, WellKnownTypes.ErrorType); + constraint.ReportDiagnostic(diagnostics, diag => diag + .WithTemplate(TypeCheckingErrors.TypeMismatch) + .WithFormatArgs( + functionType, + this.MakeMismatchedFunctionType(constraint.Arguments, functionType.ReturnType))); + constraint.CompletionSource.SetResult(default); + return; } // We are done diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Utils.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Utils.cs index d4216908c..fda73e320 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Utils.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Utils.cs @@ -1,9 +1,6 @@ -using System; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Diagnostics; using System.Linq; -using Draco.Compiler.Internal.Binding; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Symbols.Synthetized; @@ -11,8 +8,6 @@ namespace Draco.Compiler.Internal.Solver; internal sealed partial class ConstraintSolver { - private readonly record struct OverloadCandidate(FunctionSymbol Symbol, CallScore Score); - private FunctionTypeSymbol MakeMismatchedFunctionType(ImmutableArray args, TypeSymbol returnType) => new( args // TODO: We are passing null here... @@ -21,26 +16,6 @@ internal sealed partial class ConstraintSolver .ToImmutableArray(), returnType); - private static ImmutableArray GetDominatingCandidates(IReadOnlyList candidates) - { - // For a single candidate, don't bother - if (candidates.Count == 1) return [candidates[0].Symbol]; - - // We have more than one, find the max dominator - // NOTE: This might not be the actual dominator in case of mutual non-dominance - var bestScore = CallScore.FindBest(candidates.Select(c => c.Score)); - // We keep every candidate that dominates this score, or there is mutual non-dominance - var dominatingCandidates = candidates - .Where(pair => bestScore is null - || CallScore.Compare(bestScore.Value, pair.Score) - is CallScoreComparison.Equal - or CallScoreComparison.NoDominance) - .Select(pair => pair.Symbol) - .ToImmutableArray(); - Debug.Assert(dominatingCandidates.Length > 0); - return dominatingCandidates; - } - private FunctionSymbol ChooseSymbol(FunctionSymbol chosen) { // Nongeneric, just return @@ -62,197 +37,5 @@ private FunctionSymbol ChooseSymbol(FunctionSymbol chosen) private void UnifyParameterWithArgument(TypeSymbol paramType, Argument argument) => _ = this.Assignable( paramType, argument.Type, - argument.Syntax is null ? ConstraintLocator.Null : ConstraintLocator.Syntax(argument.Syntax)); - - private static bool MatchesParameterCount(FunctionSymbol function, int argc) - { - // Exact count match is always eligibe by only param count - if (function.Parameters.Length == argc) return true; - // If not variadic, we do need an exact match - if (!function.IsVariadic) return false; - // Otherise, there must be one less, exactly as many, or more arguments - // - one less means nullary variadics - // - exact match is one variadic - // - more is more variadics - if (argc + 1 >= function.Parameters.Length) return true; - // No match - return false; - } - - private bool RefineOverloadScores( - List candidates, - ImmutableArray arguments, - out bool wellDefined) - { - var changed = false; - wellDefined = true; - // Iterate through all candidates - for (var i = 0; i < candidates.Count;) - { - var candidate = candidates[i]; - - // Compute any undefined arguments - changed = this.AdjustOverloadScore(candidate, arguments) || changed; - // We consider having a 0-element well-defined, since we are throwing it away - var hasZero = candidate.Score.HasZero; - wellDefined = wellDefined && (candidate.Score.IsWellDefined || hasZero); - - // If any of the score vector components reached 0, we exclude the candidate - if (hasZero) - { - candidates.RemoveAt(i); - } - else - { - // Otherwise it stays - ++i; - } - } - return changed; - } - - private bool AdjustScore(FunctionTypeSymbol candidate, ImmutableArray args, CallScore scoreVector) - { - Debug.Assert(candidate.Parameters.Length == args.Length); - Debug.Assert(candidate.Parameters.Length == scoreVector.Length); - - var changed = false; - for (var i = 0; i < scoreVector.Length; ++i) - { - var param = candidate.Parameters[i]; - var arg = args[i]; - var score = scoreVector[i]; - - // If the argument is not null, it means we have already scored it - if (score is not null) continue; - - score = ScoreArgument(param, arg.Type); - changed = changed || score is not null; - scoreVector[i] = score; - - // If the score hit 0, terminate early, this overload got eliminated - if (score == 0) return changed; - } - return changed; - } - - private bool AdjustOverloadScore(OverloadCandidate candidate, ImmutableArray arguments) - { - var changed = false; - var (func, scoreVector) = candidate; - - for (var i = 0; i < scoreVector.Length; ++i) - { - var param = func.Parameters[i]; - // Handle that separately - if (param.IsVariadic) continue; - - if (arguments.Length == i) - { - // Special case, this call was extended because of variadics - if (scoreVector[i] is null) - { - scoreVector[i] = FullScore; - changed = true; - } - continue; - } - - var argType = arguments[i].Type; - var score = scoreVector[i]; - - // If the argument is not null, it means we have already scored it - if (score is not null) continue; - - score = ScoreArgument(param, argType); - changed = changed || score is not null; - scoreVector[i] = score; - - // If the score hit 0, terminate early, this overload got eliminated - if (score == 0) return changed; - } - // Handle variadic arguments - if (func.IsVariadic && scoreVector[^1] is null) - { - var variadicParam = func.Parameters[^1]; - var variadicArgTypes = arguments - .Skip(func.Parameters.Length - 1) - .Select(a => a.Type); - var score = ScoreVariadicArguments(variadicParam, variadicArgTypes); - changed = changed || score is not null; - scoreVector[^1] = score; - } - return changed; - } - - /// - /// Scores a sequence of variadic function call argument. - /// - /// The variadic function parameter. - /// The passed in argument types. - /// The score of the match. - private static int? ScoreVariadicArguments(ParameterSymbol param, IEnumerable argTypes) - { - if (!param.IsVariadic) throw new ArgumentException("the provided parameter is not variadic", nameof(param)); - if (!BinderFacts.TryGetVariadicElementType(param.Type, out var elementType)) return 0; - - return argTypes - .Select(argType => ScoreArgument(elementType, argType)) - .Append(FullScore) - .Select(s => s / 2) - .Min(); - } - - /// - /// Scores a function call argument. - /// - /// The function parameter. - /// The passed in argument type. - /// The score of the match. - private static int? ScoreArgument(ParameterSymbol param, TypeSymbol argType) - { - if (param.IsVariadic) throw new ArgumentException("the provided parameter variadic", nameof(param)); - return ScoreArgument(param.Type, argType); - } - - private const int FullScore = 16; - private const int HalfScore = 8; - private const int ZeroScore = 0; - - private static int? ScoreArgument(TypeSymbol paramType, TypeSymbol argType) - { - paramType = paramType.Substitution; - argType = argType.Substitution; - - // If either are still not ground types, we can't decide - if (!paramType.IsGroundType || !argType.IsGroundType) return null; - - // Exact equality is max score - if (SymbolEqualityComparer.Default.Equals(paramType, argType)) return FullScore; - - // Base type match is half score - if (SymbolEqualityComparer.Default.IsBaseOf(paramType, argType)) return HalfScore; - - // TODO: Unspecified what happens for generics - // For now we require an exact match and score is the lowest score among generic args - if (paramType.IsGenericInstance && argType.IsGenericInstance) - { - var paramGenericDefinition = paramType.GenericDefinition!; - var argGenericDefinition = argType.GenericDefinition!; - - if (!SymbolEqualityComparer.Default.Equals(paramGenericDefinition, argGenericDefinition)) return ZeroScore; - - Debug.Assert(paramType.GenericArguments.Length == argType.GenericArguments.Length); - return paramType.GenericArguments - .Zip(argType.GenericArguments) - .Select(pair => ScoreArgument(pair.First, pair.Second)) - .Min(); - } - - // Type parameter match is half score - if (paramType is TypeParameterSymbol) return HalfScore; - - // Otherwise, no match - return ZeroScore; - } + ConstraintLocator.Syntax(argument.Syntax)); } diff --git a/src/Draco.Compiler/Internal/Solver/OverloadConstraint.cs b/src/Draco.Compiler/Internal/Solver/OverloadConstraint.cs index 44f13a0f1..b5264a5ef 100644 --- a/src/Draco.Compiler/Internal/Solver/OverloadConstraint.cs +++ b/src/Draco.Compiler/Internal/Solver/OverloadConstraint.cs @@ -1,4 +1,7 @@ +using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; +using Draco.Compiler.Internal.Solver.OverloadResolution; using Draco.Compiler.Internal.Symbols; namespace Draco.Compiler.Internal.Solver; @@ -8,27 +11,29 @@ namespace Draco.Compiler.Internal.Solver; /// internal sealed class OverloadConstraint( string name, - ImmutableArray candidates, - ImmutableArray arguments, + OverloadCandidateSet candidateSet, TypeSymbol returnType, ConstraintLocator locator) : Constraint(locator) { - private readonly record struct Candidate(FunctionSymbol Symbol, CallScore Score); - /// /// The function name. /// public string Name { get; } = name; + /// + /// The set of candidates. + /// + public OverloadCandidateSet CandidateSet { get; } = candidateSet; + /// /// The candidate functions to search among. /// - public ImmutableArray Candidates { get; } = candidates; + public IEnumerable Candidates => this.CandidateSet.Select(c => c.Data); /// /// The arguments the function was called with. /// - public ImmutableArray Arguments { get; } = arguments; + public ImmutableArray Arguments => this.CandidateSet.Arguments; /// /// The return type of the call. diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/Argument.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/Argument.cs new file mode 100644 index 000000000..859b19d33 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/Argument.cs @@ -0,0 +1,11 @@ +using Draco.Compiler.Api.Syntax; +using Draco.Compiler.Internal.Symbols; + +namespace Draco.Compiler.Internal.Solver.OverloadResolution; + +/// +/// Represents an argument to a function. +/// +/// The syntax of the argument. +/// The type of the argument. +internal readonly record struct Argument(SyntaxNode? Syntax, TypeSymbol Type); diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs new file mode 100644 index 000000000..c7943e07e --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs @@ -0,0 +1,116 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Draco.Compiler.Internal.Binding; +using Draco.Compiler.Internal.Symbols; + +namespace Draco.Compiler.Internal.Solver.OverloadResolution; + +/// +/// Utility for argument scoring. +/// +internal static class ArgumentScore +{ + /// + /// An undefined score. + /// + public const int Undefined = -1; + + /// + /// Maximum score for a full match. + /// + public const int FullScore = 16; + + private const int HalfScore = 8; + private const int ZeroScore = 0; + + /// + /// Scores a sequence of variadic function call argument. + /// + /// The variadic function parameter. + /// The passed in arguments. + /// The score of the match. + public static int ScoreVariadicArguments(ParameterSymbol param, IEnumerable args) => + ScoreVariadicArguments(param, args.Select(arg => arg.Type)); + + /// + /// Scores a sequence of variadic function call argument. + /// + /// The variadic function parameter. + /// The passed in argument types. + /// The score of the match. + public static int ScoreVariadicArguments(ParameterSymbol param, IEnumerable argTypes) + { + if (!param.IsVariadic) throw new ArgumentException("the provided parameter is not variadic", nameof(param)); + if (!BinderFacts.TryGetVariadicElementType(param.Type, out var elementType)) return 0; + + return argTypes + // Score each argument + .Select(argType => ScoreArgument(elementType, argType)) + // In case the sequence is empty, we assume a full score match + .Append(FullScore) + // Every variadic argument is half as important as a normal argument + .Select(s => s / 2) + // Take the lowest score + .Min(); + } + + /// + /// Scores a function call argument. + /// + /// The function parameter. + /// The passed in argument. + /// The score of the match. + public static int ScoreArgument(ParameterSymbol param, Argument arg) => + ScoreArgument(param, arg.Type); + + /// + /// Scores a function call argument. + /// + /// The function parameter. + /// The passed in argument type. + /// The score of the match. + public static int ScoreArgument(ParameterSymbol param, TypeSymbol argType) + { + if (param.IsVariadic) throw new ArgumentException("the provided parameter variadic", nameof(param)); + return ScoreArgument(param.Type, argType); + } + + private static int ScoreArgument(TypeSymbol paramType, TypeSymbol argType) + { + paramType = paramType.Substitution; + argType = argType.Substitution; + + // If either are still not ground types, we can't decide + if (!paramType.IsGroundType || !argType.IsGroundType) return Undefined; + + // Exact equality is max score + if (SymbolEqualityComparer.Default.Equals(paramType, argType)) return FullScore; + + // Base type match is half score + if (SymbolEqualityComparer.Default.IsBaseOf(paramType, argType)) return HalfScore; + + // TODO: Unspecified what happens for generics + // For now we require an exact match and score is the lowest score among generic args + if (paramType.IsGenericInstance && argType.IsGenericInstance) + { + var paramGenericDefinition = paramType.GenericDefinition!; + var argGenericDefinition = argType.GenericDefinition!; + + if (!SymbolEqualityComparer.Default.Equals(paramGenericDefinition, argGenericDefinition)) return ZeroScore; + + Debug.Assert(paramType.GenericArguments.Length == argType.GenericArguments.Length); + return paramType.GenericArguments + .Zip(argType.GenericArguments) + .Select(pair => ScoreArgument(pair.First, pair.Second)) + .Min(); + } + + // Type parameter match is half score + if (paramType is TypeParameterSymbol) return HalfScore; + + // Otherwise, no match + return ZeroScore; + } +} diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallCandidate.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallCandidate.cs new file mode 100644 index 000000000..3457bdfb3 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallCandidate.cs @@ -0,0 +1,118 @@ +using System.Collections.Generic; +using System.Linq; +using Draco.Compiler.Internal.Symbols; +using Draco.Compiler.Internal.Utilities; + +namespace Draco.Compiler.Internal.Solver.OverloadResolution; + +/// +/// Factory methods for . +/// +internal static class CallCandidate +{ + public static CallCandidate Create(FunctionSymbol function) => + CallCandidate.Create(function); + + public static CallCandidate Create(FunctionTypeSymbol functionType) => + CallCandidate.Create(functionType); +} + +/// +/// Represents a single candidate for overload resolution. +/// +/// Additional data type. +internal readonly struct CallCandidate +{ + public static CallCandidate Create(FunctionSymbol function) => + new(function.Parameters, function.IsVariadic, function); + + // TODO: Can a function type be variadic? This is probably something we should specify... + public static CallCandidate Create(FunctionTypeSymbol functionType) => + new(functionType.Parameters, false, default); + + /// + /// The score of the candidate. + /// + public CallScore Score { get; } + + /// + /// True, if the candidate is eliminated. + /// + public bool IsEliminated => this.Score.HasZero; + + /// + /// True, if the candidate is well defined. + /// + public bool IsWellDefined => this.Score.IsWellDefined; + + /// + /// Additional data associated with the candidate. + /// + public TData Data { get; } + + private readonly IReadOnlyList parameters; + private readonly bool isVariadic; + + private CallCandidate( + IReadOnlyList parameters, + bool isVariadic, + TData data) + { + this.parameters = parameters; + this.isVariadic = isVariadic; + this.Score = new(parameters.Count); + this.Data = data; + } + + /// + /// Refines the candidate by scoring the arguments. + /// + /// The arguments to use for the refinement. + /// True, if the score got changed in any way. + public bool Refine(IReadOnlyList arguments) + { + var changed = false; + var scoreVector = this.Score; + + for (var i = 0; i < scoreVector.Length; ++i) + { + var param = this.parameters[i]; + // Handle that separately + if (param.IsVariadic) continue; + + if (arguments.Count == i) + { + // Special case, this call was extended because of variadics + if (scoreVector[i] == ArgumentScore.Undefined) + { + scoreVector[i] = ArgumentScore.FullScore; + changed = true; + } + continue; + } + + var argType = arguments[i].Type; + var score = scoreVector[i]; + + // If the argument is not null, it means we have already scored it + if (score != ArgumentScore.Undefined) continue; + + score = ArgumentScore.ScoreArgument(param, argType); + changed = changed || score != ArgumentScore.Undefined; + scoreVector[i] = score; + + // If the score hit 0, terminate early, this candidate got eliminated + if (score == 0) return changed; + } + // Handle variadic arguments + if (this.isVariadic && scoreVector[^1] == ArgumentScore.Undefined) + { + var variadicParam = this.parameters[^1]; + var variadicArgs = arguments.Skip(this.parameters.Count - 1); + var score = ArgumentScore.ScoreVariadicArguments(variadicParam, variadicArgs); + changed = changed || score != ArgumentScore.Undefined; + scoreVector[^1] = score; + } + return changed; + } +} diff --git a/src/Draco.Compiler/Internal/Solver/CallScore.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScore.cs similarity index 54% rename from src/Draco.Compiler/Internal/Solver/CallScore.cs rename to src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScore.cs index 5a5f7d81e..3cac53948 100644 --- a/src/Draco.Compiler/Internal/Solver/CallScore.cs +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScore.cs @@ -1,77 +1,20 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; -namespace Draco.Compiler.Internal.Solver; - -/// -/// Represents the comparison result of two s. -/// -internal enum CallScoreComparison -{ - /// - /// The relationship could not be determined, because the scores are not well-defined. - /// - Undetermined, - - /// - /// The first score vector dominates the second. - /// - FirstDominates, - - /// - /// The second score vector dominates the first. - /// - SecondDominates, - - /// - /// There is mutual non-dominance. - /// - NoDominance, - - /// - /// The two scores are equal. - /// - Equal, -} +namespace Draco.Compiler.Internal.Solver.OverloadResolution; /// /// Represents the score-vector for a single call. +/// Note, that this type is mutable, refining the scores as more data is available. /// internal readonly struct CallScore(int length) { - /// - /// True, if the score vector has a zero element. - /// - public bool HasZero => this.scores.Contains(0); - - /// - /// True, if the score vector is well-defined, meaning that there as no null scores. - /// - public bool IsWellDefined => !this.scores.Contains(null); - - /// - /// The length of this vector. - /// - public int Length => this.scores.Length; + private const int Undefined = ArgumentScore.Undefined; /// - /// The scores in this score vector. - /// - public int? this[int index] - { - get => this.scores[index]; - set - { - if (this.scores[index] is not null) throw new InvalidOperationException("can not modify non-null score"); - this.scores[index] = value; - } - } - - private readonly int?[] scores = new int?[length]; - - /// - /// Compares two call-scores. + /// Compares two call-scores of the same length. /// /// The first score to compare. /// The second score to compare. @@ -89,9 +32,9 @@ public static CallScoreComparison Compare(CallScore first, CallScore second) var firstScore = first[i]; var secondScore = second[i]; - if (firstScore is null || secondScore is null) return CallScoreComparison.Undetermined; + if (firstScore == Undefined || secondScore == Undefined) return CallScoreComparison.Undetermined; - var scoreComparison = firstScore.Value.CompareTo(secondScore.Value); + var scoreComparison = firstScore.CompareTo(secondScore); relation = relation switch { CallScoreComparison.FirstDominates when scoreComparison < 0 => CallScoreComparison.NoDominance, @@ -106,6 +49,37 @@ public static CallScoreComparison Compare(CallScore first, CallScore second) return relation; } + /// + /// Finds the dominating scores among a sequence of items. + /// + /// THe type of items to find the dominators aming. + /// The items to find the dominators among. + /// The score selector function. + /// The dominating elements among . + public static IEnumerable FindDominatorsBy( + IEnumerable items, + Func scoreSelector) + { + var candidates = items.ToList(); + + // Optimization, for a single or no candidate, don't bother + if (candidates.Count <= 1) return candidates; + + // We have more than one, find the max dominator + // NOTE: This might not be the actual dominator in case of mutual non-dominance + var dominatingScore = FindDominatingScore(candidates.Select(scoreSelector)); + // We keep every candidate that dominates this score, or there is mutual non-dominance + var dominatingCandidates = candidates + .Where(candidate => dominatingScore is null + || Compare(dominatingScore.Value, scoreSelector(candidate)) + is CallScoreComparison.Equal + or CallScoreComparison.NoDominance) + .ToList(); + + Debug.Assert(dominatingCandidates.Count > 0); + return dominatingCandidates; + } + /// /// Finds the best call score in a sequence. The result heavily depends on the order, in case there is mutual /// non-dominance. @@ -113,7 +87,7 @@ public static CallScoreComparison Compare(CallScore first, CallScore second) /// The scores to find the best in. /// The best score, or null if can't be determined because of an empty sequence or non-well-defined /// score vectors. - public static CallScore? FindBest(IEnumerable scores) + private static CallScore? FindDominatingScore(IEnumerable scores) { var enumerator = scores.GetEnumerator(); if (!enumerator.MoveNext()) return null; @@ -129,4 +103,35 @@ public static CallScoreComparison Compare(CallScore first, CallScore second) } return best; } + + /// + /// True, if the score vector has a zero element, in which case the call is a guaranteed no match. + /// + public bool HasZero => this.scores.Contains(0); + + /// + /// True, if the score vector is well-defined, meaning that there as no undefined scores. + /// + public bool IsWellDefined => !this.scores.Contains(Undefined); + + /// + /// The length of this vector. + /// + public int Length => this.scores.Length; + + /// + /// The scores in this score vector. + /// + public int this[int index] + { + get => this.scores[index]; + set + { + if (index < 0 || index >= this.Length) throw new ArgumentOutOfRangeException(nameof(index)); + if (this.scores[index] != Undefined) throw new InvalidOperationException("can not modify non-null score"); + this.scores[index] = value; + } + } + + private readonly int[] scores = Enumerable.Repeat(Undefined, length).ToArray(); } diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScoreComparison.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScoreComparison.cs new file mode 100644 index 000000000..0f74854b3 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallScoreComparison.cs @@ -0,0 +1,32 @@ +namespace Draco.Compiler.Internal.Solver.OverloadResolution; + +/// +/// Represents the comparison result of two s. +/// +internal enum CallScoreComparison +{ + /// + /// The relationship could not be determined, because the scores are not well-defined. + /// + Undetermined, + + /// + /// The first score vector dominates the second. + /// + FirstDominates, + + /// + /// The second score vector dominates the first. + /// + SecondDominates, + + /// + /// There is mutual non-dominance. + /// + NoDominance, + + /// + /// The two scores are equal. + /// + Equal, +} diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallUtilities.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallUtilities.cs new file mode 100644 index 000000000..edaeb5d6a --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/CallUtilities.cs @@ -0,0 +1,31 @@ +using Draco.Compiler.Internal.Symbols; + +namespace Draco.Compiler.Internal.Solver.OverloadResolution; + +/// +/// Utilities for calls and overload resolution. +/// +internal static class CallUtilities +{ + /// + /// Checks, if a function matches a given parameter count. + /// + /// The function to check. + /// The number of arguments passed in. + /// True, if can be called with number + /// of arguments. + public static bool MatchesParameterCount(FunctionSymbol function, int argc) + { + // Exact count match is always eligibe by only param count + if (function.Parameters.Length == argc) return true; + // If not variadic, we do need an exact match + if (!function.IsVariadic) return false; + // Otherise, there must be one less, exactly as many, or more arguments + // - one less means nullary variadics + // - exact match is one variadic + // - more is more variadics + if (argc + 1 >= function.Parameters.Length) return true; + // No match + return false; + } +} diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/OverloadCandidateSet.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/OverloadCandidateSet.cs new file mode 100644 index 000000000..1a19a6f99 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/OverloadCandidateSet.cs @@ -0,0 +1,106 @@ +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Draco.Compiler.Internal.Symbols; + +namespace Draco.Compiler.Internal.Solver.OverloadResolution; + +/// +/// A set of overload candidates. +/// +internal readonly struct OverloadCandidateSet : IReadOnlyCollection> +{ + /// + /// Constructs a new set of overload candidates. + /// + /// The candidate functions. + /// The arguments the candidates are called with. + /// The constructed candidate set. + public static OverloadCandidateSet Create( + IEnumerable candidates, + IEnumerable arguments) + { + var argList = arguments.ToImmutableArray(); + var remainingCandidates = candidates + .Where(c => CallUtilities.MatchesParameterCount(c, argList.Length)) + .Select(CallCandidate.Create) + .ToList(); + return new(remainingCandidates, argList); + } + + /// + /// The arguments the candidates were called with. + /// + public ImmutableArray Arguments { get; } + + /// + /// The remaining candidates. + /// + public IEnumerable> Candidates => this.candidates; + + /// + /// True, if the set is well defined, meaning that there is no need to further refine the candidates. + /// This can mean that there is only one candidate left, no candidates are left, or even + /// that the remaining candidates are well defined, meaning ambiguity. + /// + public bool IsWellDefined => this.Count <= 1 + || this.candidates.All(c => c.IsWellDefined); + + /// + /// True, if the set is ambiguous, meaning there are multiple remaining candidates. + /// + public bool IsAmbiguous => this.Count > 1; + + public int Count => this.candidates.Count; + + private readonly List> candidates; + + private OverloadCandidateSet( + List> candidates, + ImmutableArray arguments) + { + this.Arguments = arguments; + this.candidates = candidates; + } + + /// + /// Refines the candidates by scoring the arguments and eliminating any invalid candidates. + /// + /// True, if the refinement changed the set of candidates in some way. + public bool Refine() + { + var changed = false; + while (this.RefineOnce()) changed = true; + return changed; + } + + private bool RefineOnce() + { + var changed = false; + // Iterate through all candidates + for (var i = 0; i < this.Count;) + { + var candidate = this.candidates[i]; + + // Compute any undefined arguments + changed = candidate.Refine(this.Arguments) || changed; + + // Remove eliminated candidates + if (candidate.IsEliminated) + { + this.candidates.RemoveAt(i); + changed = true; + } + else + { + // Otherwise it stays + ++i; + } + } + return changed; + } + + public IEnumerator> GetEnumerator() => this.candidates.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); +} diff --git a/src/Draco.LanguageServer/Capabilities/Rename.cs b/src/Draco.LanguageServer/Capabilities/Rename.cs index f498cef8a..c88deb252 100644 --- a/src/Draco.LanguageServer/Capabilities/Rename.cs +++ b/src/Draco.LanguageServer/Capabilities/Rename.cs @@ -57,7 +57,7 @@ internal partial class DracoLanguageServer : IRename { Changes = textEdits.ToDictionary( e => new DocumentUri(e.Path!.LocalPath), - e => e.Edits.ToList() as IList), + e => e.Edits.Cast().ToList() as IList), }); } diff --git a/src/Draco.LanguageServer/Capabilities/TextDocumentSync.cs b/src/Draco.LanguageServer/Capabilities/TextDocumentSync.cs index 41c092124..be9971841 100644 --- a/src/Draco.LanguageServer/Capabilities/TextDocumentSync.cs +++ b/src/Draco.LanguageServer/Capabilities/TextDocumentSync.cs @@ -1,4 +1,3 @@ -using System; using System.Linq; using System.Threading.Tasks; using Draco.Lsp.Model; diff --git a/src/Draco.ProjectSystem/DesignTimeBuild.cs b/src/Draco.ProjectSystem/DesignTimeBuild.cs index edcef4e43..81ba17bef 100644 --- a/src/Draco.ProjectSystem/DesignTimeBuild.cs +++ b/src/Draco.ProjectSystem/DesignTimeBuild.cs @@ -1,7 +1,6 @@ using System.Collections.Immutable; using System.IO; using System.Linq; -using MSBuildProject = Microsoft.Build.Evaluation.Project; using MSBuildProjectInstance = Microsoft.Build.Execution.ProjectInstance; namespace Draco.ProjectSystem; diff --git a/src/Draco.ProjectSystem/Draco.ProjectSystem.csproj b/src/Draco.ProjectSystem/Draco.ProjectSystem.csproj index 1b5c250e8..4256a2860 100644 --- a/src/Draco.ProjectSystem/Draco.ProjectSystem.csproj +++ b/src/Draco.ProjectSystem/Draco.ProjectSystem.csproj @@ -6,6 +6,6 @@ - + diff --git a/src/Draco.ProjectSystem/Project.cs b/src/Draco.ProjectSystem/Project.cs index 55e158da7..5405ef9ba 100644 --- a/src/Draco.ProjectSystem/Project.cs +++ b/src/Draco.ProjectSystem/Project.cs @@ -1,10 +1,8 @@ using System; using System.Collections.Generic; -using System.Collections.Immutable; using System.IO; using System.Linq; using Microsoft.Build.Evaluation; -using Microsoft.Build.Execution; using MSBuildProject = Microsoft.Build.Evaluation.Project; using MSBuildProjectInstance = Microsoft.Build.Execution.ProjectInstance; diff --git a/src/Draco.SourceGeneration/Draco.SourceGeneration.csproj b/src/Draco.SourceGeneration/Draco.SourceGeneration.csproj index 5689cc28b..bde8e3777 100644 --- a/src/Draco.SourceGeneration/Draco.SourceGeneration.csproj +++ b/src/Draco.SourceGeneration/Draco.SourceGeneration.csproj @@ -17,7 +17,7 @@ - +