From 8dc28760ac4d697ce2c12de3af5035433409182f Mon Sep 17 00:00:00 2001 From: LPeter1997 Date: Mon, 12 Aug 2024 17:15:02 +0200 Subject: [PATCH] Fixing type bugs (#423) * Create OverloadingTests.cs * Update OverloadingTests.cs * Added cache to type provider * Cache base types and dedup them * Missing stuff * Update ConstraintSolver_Rules.cs * Update TypeProvider.cs * Update ConstraintSolver.cs * Shuffling logic around * Update SymbolEqualityComparer.cs * Fleshed out hashing a bit --- .../Semantics/OverloadingTests.cs | 225 ++++++++++++++++++ .../Internal/Solver/ConstraintSolver.cs | 102 -------- .../Solver/ConstraintSolver_Operations.cs | 137 +++++++++++ .../Internal/Solver/ConstraintSolver_Rules.cs | 6 +- .../OverloadResolution/ArgumentScore.cs | 16 +- .../Internal/Symbols/Metadata/TypeProvider.cs | 95 ++++++-- .../Symbols/SymbolEqualityComparer.cs | 31 ++- .../Internal/Symbols/TypeSymbol.cs | 17 +- 8 files changed, 486 insertions(+), 143 deletions(-) create mode 100644 src/Draco.Compiler.Tests/Semantics/OverloadingTests.cs create mode 100644 src/Draco.Compiler/Internal/Solver/ConstraintSolver_Operations.cs diff --git a/src/Draco.Compiler.Tests/Semantics/OverloadingTests.cs b/src/Draco.Compiler.Tests/Semantics/OverloadingTests.cs new file mode 100644 index 000000000..1ef5f4fdc --- /dev/null +++ b/src/Draco.Compiler.Tests/Semantics/OverloadingTests.cs @@ -0,0 +1,225 @@ +using System.Diagnostics; +using Draco.Compiler.Api; +using Draco.Compiler.Api.Syntax; +using Draco.Compiler.Internal.Symbols; +using static Draco.Compiler.Api.Syntax.SyntaxFactory; + +namespace Draco.Compiler.Tests.Semantics; + +public sealed class OverloadingTests : SemanticTestsBase +{ + // func foo(l: List) {} + // func foo(l: List) {} + // func foo(l: List) {} + // func foo(l: IEnumerable) {} + // func foo(l: IEnumerable) {} + // func foo(l: IEnumerable) {} + private static IEnumerable GetGenericListOverloads() => [ + FunctionDeclaration( + "foo", + ParameterList(Parameter("l", GenericType(NameType("List"), NameType("int32")))), + null, + BlockFunctionBody()), + FunctionDeclaration( + "foo", + ParameterList(Parameter("l", GenericType(NameType("List"), NameType("string")))), + null, + BlockFunctionBody()), + FunctionDeclaration( + "foo", + GenericParameterList(GenericParameter("T")), + ParameterList(Parameter("l", GenericType(NameType("List"), NameType("T")))), + null, + BlockFunctionBody()), + FunctionDeclaration( + "foo", + ParameterList(Parameter("l", GenericType(NameType("IEnumerable"), NameType("int32")))), + null, + BlockFunctionBody()), + FunctionDeclaration( + "foo", + ParameterList(Parameter("l", GenericType(NameType("IEnumerable"), NameType("string")))), + null, + BlockFunctionBody()), + FunctionDeclaration( + "foo", + GenericParameterList(GenericParameter("T")), + ParameterList(Parameter("l", GenericType(NameType("IEnumerable"), NameType("T")))), + null, + BlockFunctionBody()), + ]; + + // import System.Collections.Generic; + // import System.Linq.Enumerable; + // + // ... foo overloads ... + // + // func main() { + // call(); + // + private static SyntaxTree CreateOverloadTree(CallExpressionSyntax call) => SyntaxTree.Create(CompilationUnit( + new DeclarationSyntax[] + { + ImportDeclaration("System", "Collections", "Generic"), + ImportDeclaration("System", "Linq", "Enumerable"), + }.Concat(GetGenericListOverloads()) + .Append(FunctionDeclaration( + "main", + ParameterList(), + null, + BlockFunctionBody(ExpressionStatement(call)))))); + + private static FunctionSymbol GetDeclaredFunctionSymbol(Compilation compilation, int index) + { + var syntaxTree = compilation.SyntaxTrees.Single(); + var syntax = syntaxTree.FindInChildren(index); + Debug.Assert(syntax is not null); + + var semanticModel = compilation.GetSemanticModel(syntaxTree); + var symbol = semanticModel.GetDeclaredSymbol(syntax); + Debug.Assert(symbol!.Name == "foo"); + + return GetInternalSymbol(symbol); + } + + private static FunctionSymbol GetCalledFunctionSymbol(Compilation compilation) + { + var syntaxTree = compilation.SyntaxTrees.Single(); + var syntax = syntaxTree.FindInChildren(); + Debug.Assert(syntax is not null); + + var semanticModel = compilation.GetSemanticModel(syntaxTree); + var symbol = semanticModel.GetReferencedSymbol(syntax); + Debug.Assert(symbol!.Name == "foo"); + + return GetInternalSymbol(symbol); + } + + [Fact] + public void ListInt32Overload() + { + // foo(List()); + + // Arrange + var tree = CreateOverloadTree(CallExpression( + NameExpression("foo"), + CallExpression(GenericExpression(NameExpression("List"), NameType("int32"))))); + + // Act + var compilation = CreateCompilation(tree); + var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 0); + var calledSymbol = GetCalledFunctionSymbol(compilation); + + // Assert + Assert.Empty(compilation.Diagnostics); + Assert.Same(expectedSymbol, calledSymbol); + } + + [Fact] + public void ListStringOverload() + { + // foo(List()); + + // Arrange + var tree = CreateOverloadTree(CallExpression( + NameExpression("foo"), + CallExpression(GenericExpression(NameExpression("List"), NameType("string"))))); + var compilation = CreateCompilation(tree); + + // Act + var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 1); + var calledSymbol = GetCalledFunctionSymbol(compilation); + + // Assert + Assert.Empty(compilation.Diagnostics); + Assert.Same(expectedSymbol, calledSymbol); + } + + [Fact] + public void ListGenericOverload() + { + // foo(List()); + + // Arrange + var tree = CreateOverloadTree(CallExpression( + NameExpression("foo"), + CallExpression(GenericExpression(NameExpression("List"), NameType("bool"))))); + var compilation = CreateCompilation(tree); + + // Act + var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 2); + var calledSymbol = GetCalledFunctionSymbol(compilation); + + // Assert + Assert.Empty(compilation.Diagnostics); + Assert.True(calledSymbol.IsGenericInstance); + Assert.Same(expectedSymbol, calledSymbol.GenericDefinition); + } + + [Fact] + public void IEnumerableInt32Overload() + { + // foo(AsEnumerable(List())); + + // Arrange + var tree = CreateOverloadTree(CallExpression( + NameExpression("foo"), + CallExpression( + NameExpression("AsEnumerable"), + CallExpression(GenericExpression(NameExpression("List"), NameType("int32")))))); + var compilation = CreateCompilation(tree); + + // Act + var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 3); + var calledSymbol = GetCalledFunctionSymbol(compilation); + + // Assert + Assert.Empty(compilation.Diagnostics); + Assert.Same(expectedSymbol, calledSymbol); + } + + [Fact] + public void IEnumerableStringOverload() + { + // foo(AsEnumerable(List())); + + // Arrange + var tree = CreateOverloadTree(CallExpression( + NameExpression("foo"), + CallExpression( + NameExpression("AsEnumerable"), + CallExpression(GenericExpression(NameExpression("List"), NameType("string")))))); + var compilation = CreateCompilation(tree); + + // Act + var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 4); + var calledSymbol = GetCalledFunctionSymbol(compilation); + + // Assert + Assert.Empty(compilation.Diagnostics); + Assert.Same(expectedSymbol, calledSymbol); + } + + [Fact] + public void IEnumerableGenericOverload() + { + // foo(AsEnumerable(List())); + + // Arrange + var tree = CreateOverloadTree(CallExpression( + NameExpression("foo"), + CallExpression( + NameExpression("AsEnumerable"), + CallExpression(GenericExpression(NameExpression("List"), NameType("bool")))))); + var compilation = CreateCompilation(tree); + + // Act + var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 5); + var calledSymbol = GetCalledFunctionSymbol(compilation); + + // Assert + Assert.Empty(compilation.Diagnostics); + Assert.True(calledSymbol.IsGenericInstance); + Assert.Same(expectedSymbol, calledSymbol.GenericDefinition); + } +} diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs index 975810fa1..06090b6ae 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs @@ -150,106 +150,4 @@ public TypeVariable AllocateTypeVariable(bool track = true) if (track) this.typeVariables.Add(typeVar); return typeVar; } - - /// - /// Unwraps the potential type-variable until it is a non-type-variable type. - /// - /// The type to unwrap. - /// The task that completes when is subsituted as a non-type-variable. - public static async SolverTask Substituted(TypeSymbol type) - { - while (type is TypeVariable tv) type = await tv.Substituted; - return type; - } - - /// - /// Unifies two types, asserting their success. - /// - /// The first type to unify. - /// The second type to unify. - public static void UnifyAsserted(TypeSymbol first, TypeSymbol second) - { - if (Unify(first, second)) return; - throw new System.InvalidOperationException($"could not unify {first} and {second}"); - } - - /// - /// Attempts to unify two types. - /// - /// The first type to unify. - /// The second type to unify. - /// True, if unification was successful, false otherwise. - public static bool Unify(TypeSymbol first, TypeSymbol second) - { - first = first.Substitution; - second = second.Substitution; - - // NOTE: Referential equality is OK here, we don't need to use SymbolEqualityComparer, this is unification - if (ReferenceEquals(first, second)) return true; - - switch (first, second) - { - // Type variable substitution takes priority - // so it can unify with never type and error type to stop type errors from cascading - case (TypeVariable v1, TypeVariable v2): - { - // Check for circularity - // NOTE: Referential equality is OK here, we are checking for CIRCULARITY - // which is referential check - if (ReferenceEquals(v1, v2)) return true; - v1.Substitute(v2); - return true; - } - case (TypeVariable v, TypeSymbol other): - { - v.Substitute(other); - return true; - } - case (TypeSymbol other, TypeVariable v): - { - v.Substitute(other); - return true; - } - - // Never type is never reached, unifies with everything - case (NeverTypeSymbol, _): - case (_, NeverTypeSymbol): - // Error type unifies with everything to avoid cascading type errors - case (ErrorTypeSymbol, _): - case (_, ErrorTypeSymbol): - return true; - - case (ArrayTypeSymbol a1, ArrayTypeSymbol a2) when a1.IsGenericDefinition && a2.IsGenericDefinition: - return a1.Rank == a2.Rank; - - // NOTE: Primitives are filtered out already, along with metadata types - - case (FunctionTypeSymbol f1, FunctionTypeSymbol f2): - { - if (f1.Parameters.Length != f2.Parameters.Length) return false; - for (var i = 0; i < f1.Parameters.Length; ++i) - { - if (!Unify(f1.Parameters[i].Type, f2.Parameters[i].Type)) return false; - } - return Unify(f1.ReturnType, f2.ReturnType); - } - - case (_, _) when first.IsGenericInstance && second.IsGenericInstance: - { - // NOTE: Generic instances might not obey referential equality - Debug.Assert(first.GenericDefinition is not null); - Debug.Assert(second.GenericDefinition is not null); - if (first.GenericArguments.Length != second.GenericArguments.Length) return false; - if (!Unify(first.GenericDefinition, second.GenericDefinition)) return false; - for (var i = 0; i < first.GenericArguments.Length; ++i) - { - if (!Unify(first.GenericArguments[i], second.GenericArguments[i])) return false; - } - return true; - } - - default: - return false; - } - } } diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Operations.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Operations.cs new file mode 100644 index 000000000..114021247 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Operations.cs @@ -0,0 +1,137 @@ +using Draco.Compiler.Internal.Symbols.Error; +using Draco.Compiler.Internal.Symbols.Synthetized; +using Draco.Compiler.Internal.Symbols; +using System.Diagnostics; +using Draco.Compiler.Internal.Solver.Constraints; +using System; +using System.Linq; + +namespace Draco.Compiler.Internal.Solver; + +internal sealed partial class ConstraintSolver +{ + /// + /// Unifies two types, asserting their success. + /// + /// The first type to unify. + /// The second type to unify. + public static void UnifyAsserted(TypeSymbol first, TypeSymbol second) + { + if (Unify(first, second)) return; + throw new System.InvalidOperationException($"could not unify {first} and {second}"); + } + + /// + /// Attempts to unify two types. + /// + /// The first type to unify. + /// The second type to unify. + /// True, if unification was successful, false otherwise. + public static bool Unify(TypeSymbol first, TypeSymbol second) + { + first = first.Substitution; + second = second.Substitution; + + // NOTE: Referential equality is OK here, we don't need to use SymbolEqualityComparer, this is unification + if (ReferenceEquals(first, second)) return true; + + switch (first, second) + { + // Type variable substitution takes priority + // so it can unify with never type and error type to stop type errors from cascading + case (TypeVariable v1, TypeVariable v2): + { + // Check for circularity + // NOTE: Referential equality is OK here, we are checking for CIRCULARITY + // which is referential check + if (ReferenceEquals(v1, v2)) return true; + v1.Substitute(v2); + return true; + } + case (TypeVariable v, TypeSymbol other): + { + v.Substitute(other); + return true; + } + case (TypeSymbol other, TypeVariable v): + { + v.Substitute(other); + return true; + } + + // Never type is never reached, unifies with everything + case (NeverTypeSymbol, _): + case (_, NeverTypeSymbol): + // Error type unifies with everything to avoid cascading type errors + case (ErrorTypeSymbol, _): + case (_, ErrorTypeSymbol): + return true; + + case (ArrayTypeSymbol a1, ArrayTypeSymbol a2) when a1.IsGenericDefinition && a2.IsGenericDefinition: + return a1.Rank == a2.Rank; + + // NOTE: Primitives are filtered out already, along with metadata types + + case (FunctionTypeSymbol f1, FunctionTypeSymbol f2): + { + if (f1.Parameters.Length != f2.Parameters.Length) return false; + for (var i = 0; i < f1.Parameters.Length; ++i) + { + if (!Unify(f1.Parameters[i].Type, f2.Parameters[i].Type)) return false; + } + return Unify(f1.ReturnType, f2.ReturnType); + } + + case (_, _) when first.IsGenericInstance && second.IsGenericInstance: + { + // NOTE: Generic instances might not obey referential equality + Debug.Assert(first.GenericDefinition is not null); + Debug.Assert(second.GenericDefinition is not null); + if (first.GenericArguments.Length != second.GenericArguments.Length) return false; + if (!Unify(first.GenericDefinition, second.GenericDefinition)) return false; + for (var i = 0; i < first.GenericArguments.Length; ++i) + { + if (!Unify(first.GenericArguments[i], second.GenericArguments[i])) return false; + } + return true; + } + + default: + return false; + } + } + + public static void AssignAsserted(TypeSymbol targetType, TypeSymbol assignedType) + { + if (Assign(targetType, assignedType)) return; + throw new System.InvalidOperationException($"could not assign {assignedType} to {targetType}"); + } + + private static bool Assign(TypeSymbol targetType, TypeSymbol assignedType) + { + targetType = targetType.Substitution; + assignedType = assignedType.Substitution; + + if (targetType.IsGenericInstance && assignedType.IsGenericInstance) + { + // We need to look for the base type + var targetGenericDefinition = targetType.GenericDefinition!; + + var assignedToUnify = assignedType.BaseTypes + .FirstOrDefault(t => SymbolEqualityComparer.Default.Equals(t.GenericDefinition, targetGenericDefinition)); + if (assignedToUnify is null) + { + // TODO + throw new NotImplementedException(); + } + + // Unify + return Unify(targetType, assignedToUnify); + } + else + { + // TODO: Might not be correct + return Unify(targetType, assignedType); + } + } +} diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs index b0edaa20b..cd759f0b9 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs @@ -372,11 +372,7 @@ private IEnumerable ConstructRules(DiagnosticBag diagnostics) => [ // As a last-last effort, we assume that a singular assignment means exact matching types Simplification(typeof(Assignable)) .Body((ConstraintStore store, Assignable assignable) => - { - // TODO: Is asserted correct here? - // Maybe just for type-variables? - UnifyAsserted(assignable.TargetType, assignable.AssignedType); - }) + AssignAsserted(assignable.TargetType, assignable.AssignedType)) .Named("sole_assignable"), // As a last-effort, if we see a common ancestor constraint with a single non-type-var, we diff --git a/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs b/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs index c7943e07e..8880f3d7d 100644 --- a/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs +++ b/src/Draco.Compiler/Internal/Solver/OverloadResolution/ArgumentScore.cs @@ -20,9 +20,9 @@ internal static class ArgumentScore /// /// Maximum score for a full match. /// - public const int FullScore = 16; + public const int FullScore = 32; - private const int HalfScore = 8; + private const int HalfScore = FullScore / 2; private const int ZeroScore = 0; /// @@ -91,20 +91,24 @@ private static int ScoreArgument(TypeSymbol paramType, TypeSymbol argType) // Base type match is half score if (SymbolEqualityComparer.Default.IsBaseOf(paramType, argType)) return HalfScore; - // TODO: Unspecified what happens for generics - // For now we require an exact match and score is the lowest score among generic args + // For generics we take the lowest scoring argument and half it, if the generic type is a base if (paramType.IsGenericInstance && argType.IsGenericInstance) { var paramGenericDefinition = paramType.GenericDefinition!; var argGenericDefinition = argType.GenericDefinition!; - if (!SymbolEqualityComparer.Default.Equals(paramGenericDefinition, argGenericDefinition)) return ZeroScore; + var genericDefinitionIsExact = SymbolEqualityComparer.Default.Equals(paramGenericDefinition, argGenericDefinition); + var genericDefinitionIsBase = SymbolEqualityComparer.Default.IsBaseOf(paramGenericDefinition, argGenericDefinition); + + if (!genericDefinitionIsExact && !genericDefinitionIsBase) return ZeroScore; Debug.Assert(paramType.GenericArguments.Length == argType.GenericArguments.Length); - return paramType.GenericArguments + var minGenericScore = paramType.GenericArguments .Zip(argType.GenericArguments) .Select(pair => ScoreArgument(pair.First, pair.Second)) .Min(); + + return genericDefinitionIsExact ? minGenericScore : minGenericScore / 2; } // Type parameter match is half score diff --git a/src/Draco.Compiler/Internal/Symbols/Metadata/TypeProvider.cs b/src/Draco.Compiler/Internal/Symbols/Metadata/TypeProvider.cs index 08fa5ec37..d04c17a1c 100644 --- a/src/Draco.Compiler/Internal/Symbols/Metadata/TypeProvider.cs +++ b/src/Draco.Compiler/Internal/Symbols/Metadata/TypeProvider.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; @@ -14,13 +15,26 @@ namespace Draco.Compiler.Internal.Symbols.Metadata; internal sealed class TypeProvider(Compilation compilation) : ISignatureTypeProvider, ICustomAttributeTypeProvider { - private readonly record struct CacheKey(MetadataReader Reader, EntityHandle Handle); + // We have 2 levels of caching to avoid re-creating types + // The first level is the "outer" level, which caches types by their metadata handle + // The second level is the "inner" level, which caches types by their fully qualified name + // Generally the first level is faster, but the second level is necessary for cross-assembly types and + // different type reference encodings + + private readonly record struct LightCacheKey( + MetadataReader Reader, + EntityHandle Handle); + + private readonly record struct CacheKey( + string AssemblyFullName, + string TypeFullyQualifiedName); // TODO: We return a special error type for now to swallow errors private static TypeSymbol UnknownType { get; } = new PrimitiveTypeSymbol("", false); private WellKnownTypes WellKnownTypes => compilation.WellKnownTypes; + private readonly ConcurrentDictionary lightCache = new(); private readonly ConcurrentDictionary cache = new(); public TypeSymbol GetArrayType(TypeSymbol elementType, ArrayShape shape) => @@ -88,10 +102,45 @@ public TypeSymbol GetGenericTypeParameter(Symbol genericContext, int index) public TypeSymbol GetTypeFromDefinition(MetadataReader reader, TypeDefinitionHandle handle, byte rawTypeKind) { - var key = new CacheKey(reader, handle); - return this.cache.GetOrAdd(key, _ => this.BuildTypeFromDefinition(reader, handle, rawTypeKind)); + var lightKey = new LightCacheKey(reader, handle); + return this.lightCache.GetOrAdd(lightKey, _ => + { + // Check, if the type is already cached in the primary cache + // For that we need to resolve the assembly name and the fully qualified name + var assemblyName = reader.GetAssemblyDefinition().GetAssemblyName(); + + var definition = reader.GetTypeDefinition(handle); + var @namespace = reader.GetString(definition.Namespace); + var name = reader.GetString(definition.Name); + var fullName = ConcatenateNamespaceAndName(@namespace, name); + + var key = new CacheKey(assemblyName.FullName, fullName); + return this.cache.GetOrAdd(key, _ => this.BuildTypeFromDefinition(reader, handle, rawTypeKind)); + }); } + public TypeSymbol GetTypeFromReference(MetadataReader reader, TypeReferenceHandle handle, byte rawTypeKind) + { + var lightKey = new LightCacheKey(reader, handle); + return this.lightCache.GetOrAdd(lightKey, _ => + { + var key = BuildCacheKey(reader, handle); + return this.cache.GetOrAdd(key, _ => this.BuildTypeFromReference(reader, handle, rawTypeKind)); + }); + } + + // TODO: Should we cache this as well? doesn't seem to have any effect + public TypeSymbol GetTypeFromSpecification(MetadataReader reader, Symbol genericContext, TypeSpecificationHandle handle, byte rawTypeKind) + { + var specification = reader.GetTypeSpecification(handle); + return specification.DecodeSignature(this, genericContext); + } + + public TypeSymbol GetSystemType() => this.WellKnownTypes.SystemType; + public bool IsSystemType(TypeSymbol type) => ReferenceEquals(type, this.WellKnownTypes.SystemType); + public TypeSymbol GetTypeFromSerializedName(string name) => UnknownType; + public PrimitiveTypeCode GetUnderlyingEnumType(TypeSymbol type) => throw new System.ArgumentOutOfRangeException(nameof(type)); + private TypeSymbol BuildTypeFromDefinition(MetadataReader reader, TypeDefinitionHandle handle, byte rawTypeKind) { var definition = reader.GetTypeDefinition(handle); @@ -118,23 +167,18 @@ private TypeSymbol BuildTypeFromDefinition(MetadataReader reader, TypeDefinition // Type path var @namespace = reader.GetString(definition.Namespace); var name = reader.GetString(definition.Name); - var fullName = string.IsNullOrWhiteSpace(@namespace) ? name : $"{@namespace}.{name}"; + var fullName = ConcatenateNamespaceAndName(@namespace, name); var path = fullName.Split('.').ToImmutableArray(); return this.WellKnownTypes.GetTypeFromAssembly(assemblyName, path); } - public TypeSymbol GetTypeFromReference(MetadataReader reader, TypeReferenceHandle handle, byte rawTypeKind) - { - var key = new CacheKey(reader, handle); - return this.cache.GetOrAdd(key, _ => this.BuildTypeFromReference(reader, handle, rawTypeKind)); - } - private TypeSymbol BuildTypeFromReference(MetadataReader reader, TypeReferenceHandle handle, byte rawTypeKind) { var parts = new List(); var reference = reader.GetTypeReference(handle); - parts.Add(reader.GetString(reference.Name)); + var referenceName = reader.GetString(reference.Name); + parts.Add(referenceName); EntityHandle resolutionScope; for (resolutionScope = reference.ResolutionScope; resolutionScope.Kind == HandleKind.TypeReference; resolutionScope = reference.ResolutionScope) { @@ -151,15 +195,28 @@ private TypeSymbol BuildTypeFromReference(MetadataReader reader, TypeReferenceHa return assembly.RootNamespace.Lookup([.. parts]).OfType().Single(); } - // TODO: Should we cache this as well? doesn't seem to have any effect - public TypeSymbol GetTypeFromSpecification(MetadataReader reader, Symbol genericContext, TypeSpecificationHandle handle, byte rawTypeKind) + private static CacheKey BuildCacheKey(MetadataReader reader, TypeReferenceHandle handle) { - var specification = reader.GetTypeSpecification(handle); - return specification.DecodeSignature(this, genericContext); + // Directly the type itself + var typeReference = reader.GetTypeReference(handle); + var typeName = reader.GetString(typeReference.Name); + // The reference might be nested, so we need to walk up the resolution scope chain + var scope = typeReference.ResolutionScope; + while (scope.Kind == HandleKind.TypeReference) + { + var parentReference = reader.GetTypeReference((TypeReferenceHandle)scope); + typeName = $"{reader.GetString(parentReference.Name)}.{typeName}"; + scope = parentReference.ResolutionScope; + } + // Build full name + var @namespace = reader.GetString(typeReference.Namespace); + var fullName = ConcatenateNamespaceAndName(@namespace, typeName); + // Resolve assembly name + var assemblyName = reader.GetAssemblyReference((AssemblyReferenceHandle)scope).GetAssemblyName(); + // Construct key + return new CacheKey(assemblyName.FullName, fullName); } - public TypeSymbol GetSystemType() => this.WellKnownTypes.SystemType; - public bool IsSystemType(TypeSymbol type) => ReferenceEquals(type, this.WellKnownTypes.SystemType); - public TypeSymbol GetTypeFromSerializedName(string name) => UnknownType; - public PrimitiveTypeCode GetUnderlyingEnumType(TypeSymbol type) => throw new System.ArgumentOutOfRangeException(nameof(type)); + private static string ConcatenateNamespaceAndName(string? @namespace, string name) => + string.IsNullOrWhiteSpace(@namespace) ? name : $"{@namespace}.{name}"; } diff --git a/src/Draco.Compiler/Internal/Symbols/SymbolEqualityComparer.cs b/src/Draco.Compiler/Internal/Symbols/SymbolEqualityComparer.cs index 1906d36cf..2bae80d0d 100644 --- a/src/Draco.Compiler/Internal/Symbols/SymbolEqualityComparer.cs +++ b/src/Draco.Compiler/Internal/Symbols/SymbolEqualityComparer.cs @@ -92,6 +92,22 @@ public bool Equals(TypeSymbol? x, TypeSymbol? y) return x.GenericArguments.SequenceEqual(y.GenericArguments, this); } + // TODO: Does this check belong here? + if (x.IsGenericDefinition && IsUnboundedGenericInstance(y)) + { + // Check, if x is a generic bound instance, meaning its arguments are also generic types + // TODO: This might be a nice place to check constraints in the future too? + return this.Equals(x, y.GenericDefinition); + } + + // TODO: Does this check belong here? + if (y.IsGenericDefinition && IsUnboundedGenericInstance(x)) + { + // Check, if x is a generic bound instance, meaning its arguments are also generic types + // TODO: This might be a nice place to check constraints in the future too? + return this.Equals(x.GenericDefinition, y); + } + return (x, y) switch { (ArrayTypeSymbol a1, ArrayTypeSymbol a2) @@ -117,10 +133,16 @@ public int GetHashCode([DisallowNull] TypeSymbol obj) { obj = this.Unwrap(obj); - return obj switch + if (obj.IsGenericInstance) { - _ => RuntimeHelpers.GetHashCode(obj), - }; + // Combine the hash code of the generic definition with the hash codes of the arguments + var hash = default(HashCode); + hash.Add(obj.GenericDefinition); + foreach (var arg in obj.GenericArguments) hash.Add(arg); + return hash.ToHashCode(); + } + + return RuntimeHelpers.GetHashCode(obj); } [return: NotNullIfNotNull(nameof(type))] @@ -135,4 +157,7 @@ public int GetHashCode([DisallowNull] TypeSymbol obj) } return unwrappedType; } + + private static bool IsUnboundedGenericInstance(TypeSymbol t) => + t.IsGenericInstance && t.GenericArguments.All(a => a.Substitution is TypeParameterSymbol); } diff --git a/src/Draco.Compiler/Internal/Symbols/TypeSymbol.cs b/src/Draco.Compiler/Internal/Symbols/TypeSymbol.cs index ac6005200..b38000242 100644 --- a/src/Draco.Compiler/Internal/Symbols/TypeSymbol.cs +++ b/src/Draco.Compiler/Internal/Symbols/TypeSymbol.cs @@ -1,7 +1,9 @@ +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using Draco.Compiler.Internal.Symbols.Generic; +using Draco.Compiler.Internal.Utilities; namespace Draco.Compiler.Internal.Symbols; @@ -60,14 +62,8 @@ internal abstract partial class TypeSymbol : Symbol, IMemberSymbol /// All types that can be considered the base type of this one, including this type itself. /// The types are returned in a pre-order manner, starting from this type. /// - public IEnumerable BaseTypes - { - get - { - yield return this; - foreach (var t in this.ImmediateBaseTypes.SelectMany(b => b.BaseTypes)) yield return t; - } - } + public IEnumerable BaseTypes => InterlockedUtils.InitializeDefault(ref this.baseTypes, this.BuildBaseTypes); + private ImmutableArray baseTypes; /// /// The members defined directly in this type doesn't include members from . @@ -139,6 +135,11 @@ private ImmutableArray BuildMembers() return builder.ToImmutable(); } + private ImmutableArray BuildBaseTypes() => GraphTraversal.DepthFirst( + start: this, + getNeighbors: s => s.ImmediateBaseTypes, + comparer: SymbolEqualityComparer.AllowTypeVariables).ToImmutableArray(); + public override TypeSymbol GenericInstantiate(Symbol? containingSymbol, ImmutableArray arguments) => (TypeSymbol)base.GenericInstantiate(containingSymbol, arguments); public override TypeSymbol GenericInstantiate(Symbol? containingSymbol, GenericContext context)