Skip to content

Commit

Permalink
Fixing type bugs (#423)
Browse files Browse the repository at this point in the history
* Create OverloadingTests.cs

* Update OverloadingTests.cs

* Added cache to type provider

* Cache base types and dedup them

* Missing stuff

* Update ConstraintSolver_Rules.cs

* Update TypeProvider.cs

* Update ConstraintSolver.cs

* Shuffling logic around

* Update SymbolEqualityComparer.cs

* Fleshed out hashing a bit
  • Loading branch information
LPeter1997 authored Aug 12, 2024
1 parent 42b368f commit 8dc2876
Show file tree
Hide file tree
Showing 8 changed files with 486 additions and 143 deletions.
225 changes: 225 additions & 0 deletions src/Draco.Compiler.Tests/Semantics/OverloadingTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
using System.Diagnostics;
using Draco.Compiler.Api;
using Draco.Compiler.Api.Syntax;
using Draco.Compiler.Internal.Symbols;
using static Draco.Compiler.Api.Syntax.SyntaxFactory;

namespace Draco.Compiler.Tests.Semantics;

public sealed class OverloadingTests : SemanticTestsBase
{
// func foo(l: List<int32>) {}
// func foo(l: List<string>) {}
// func foo<T>(l: List<T>) {}
// func foo(l: IEnumerable<int32>) {}
// func foo(l: IEnumerable<string>) {}
// func foo<T>(l: IEnumerable<T>) {}
private static IEnumerable<DeclarationSyntax> GetGenericListOverloads() => [
FunctionDeclaration(
"foo",
ParameterList(Parameter("l", GenericType(NameType("List"), NameType("int32")))),
null,
BlockFunctionBody()),
FunctionDeclaration(
"foo",
ParameterList(Parameter("l", GenericType(NameType("List"), NameType("string")))),
null,
BlockFunctionBody()),
FunctionDeclaration(
"foo",
GenericParameterList(GenericParameter("T")),
ParameterList(Parameter("l", GenericType(NameType("List"), NameType("T")))),
null,
BlockFunctionBody()),
FunctionDeclaration(
"foo",
ParameterList(Parameter("l", GenericType(NameType("IEnumerable"), NameType("int32")))),
null,
BlockFunctionBody()),
FunctionDeclaration(
"foo",
ParameterList(Parameter("l", GenericType(NameType("IEnumerable"), NameType("string")))),
null,
BlockFunctionBody()),
FunctionDeclaration(
"foo",
GenericParameterList(GenericParameter("T")),
ParameterList(Parameter("l", GenericType(NameType("IEnumerable"), NameType("T")))),
null,
BlockFunctionBody()),
];

// import System.Collections.Generic;
// import System.Linq.Enumerable;
//
// ... foo overloads ...
//
// func main() {
// call();
//
private static SyntaxTree CreateOverloadTree(CallExpressionSyntax call) => SyntaxTree.Create(CompilationUnit(
new DeclarationSyntax[]
{
ImportDeclaration("System", "Collections", "Generic"),
ImportDeclaration("System", "Linq", "Enumerable"),
}.Concat(GetGenericListOverloads())
.Append(FunctionDeclaration(
"main",
ParameterList(),
null,
BlockFunctionBody(ExpressionStatement(call))))));

private static FunctionSymbol GetDeclaredFunctionSymbol(Compilation compilation, int index)
{
var syntaxTree = compilation.SyntaxTrees.Single();
var syntax = syntaxTree.FindInChildren<FunctionDeclarationSyntax>(index);
Debug.Assert(syntax is not null);

var semanticModel = compilation.GetSemanticModel(syntaxTree);
var symbol = semanticModel.GetDeclaredSymbol(syntax);
Debug.Assert(symbol!.Name == "foo");

return GetInternalSymbol<FunctionSymbol>(symbol);
}

private static FunctionSymbol GetCalledFunctionSymbol(Compilation compilation)
{
var syntaxTree = compilation.SyntaxTrees.Single();
var syntax = syntaxTree.FindInChildren<CallExpressionSyntax>();
Debug.Assert(syntax is not null);

var semanticModel = compilation.GetSemanticModel(syntaxTree);
var symbol = semanticModel.GetReferencedSymbol(syntax);
Debug.Assert(symbol!.Name == "foo");

return GetInternalSymbol<FunctionSymbol>(symbol);
}

[Fact]
public void ListInt32Overload()
{
// foo(List<int32>());

// Arrange
var tree = CreateOverloadTree(CallExpression(
NameExpression("foo"),
CallExpression(GenericExpression(NameExpression("List"), NameType("int32")))));

// Act
var compilation = CreateCompilation(tree);
var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 0);
var calledSymbol = GetCalledFunctionSymbol(compilation);

// Assert
Assert.Empty(compilation.Diagnostics);
Assert.Same(expectedSymbol, calledSymbol);
}

[Fact]
public void ListStringOverload()
{
// foo(List<string>());

// Arrange
var tree = CreateOverloadTree(CallExpression(
NameExpression("foo"),
CallExpression(GenericExpression(NameExpression("List"), NameType("string")))));
var compilation = CreateCompilation(tree);

// Act
var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 1);
var calledSymbol = GetCalledFunctionSymbol(compilation);

// Assert
Assert.Empty(compilation.Diagnostics);
Assert.Same(expectedSymbol, calledSymbol);
}

[Fact]
public void ListGenericOverload()
{
// foo(List<bool>());

// Arrange
var tree = CreateOverloadTree(CallExpression(
NameExpression("foo"),
CallExpression(GenericExpression(NameExpression("List"), NameType("bool")))));
var compilation = CreateCompilation(tree);

// Act
var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 2);
var calledSymbol = GetCalledFunctionSymbol(compilation);

// Assert
Assert.Empty(compilation.Diagnostics);
Assert.True(calledSymbol.IsGenericInstance);
Assert.Same(expectedSymbol, calledSymbol.GenericDefinition);
}

[Fact]
public void IEnumerableInt32Overload()
{
// foo(AsEnumerable(List<int32>()));

// Arrange
var tree = CreateOverloadTree(CallExpression(
NameExpression("foo"),
CallExpression(
NameExpression("AsEnumerable"),
CallExpression(GenericExpression(NameExpression("List"), NameType("int32"))))));
var compilation = CreateCompilation(tree);

// Act
var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 3);
var calledSymbol = GetCalledFunctionSymbol(compilation);

// Assert
Assert.Empty(compilation.Diagnostics);
Assert.Same(expectedSymbol, calledSymbol);
}

[Fact]
public void IEnumerableStringOverload()
{
// foo(AsEnumerable(List<string>()));

// Arrange
var tree = CreateOverloadTree(CallExpression(
NameExpression("foo"),
CallExpression(
NameExpression("AsEnumerable"),
CallExpression(GenericExpression(NameExpression("List"), NameType("string"))))));
var compilation = CreateCompilation(tree);

// Act
var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 4);
var calledSymbol = GetCalledFunctionSymbol(compilation);

// Assert
Assert.Empty(compilation.Diagnostics);
Assert.Same(expectedSymbol, calledSymbol);
}

[Fact]
public void IEnumerableGenericOverload()
{
// foo(AsEnumerable(List<bool>()));

// Arrange
var tree = CreateOverloadTree(CallExpression(
NameExpression("foo"),
CallExpression(
NameExpression("AsEnumerable"),
CallExpression(GenericExpression(NameExpression("List"), NameType("bool"))))));
var compilation = CreateCompilation(tree);

// Act
var expectedSymbol = GetDeclaredFunctionSymbol(compilation, 5);
var calledSymbol = GetCalledFunctionSymbol(compilation);

// Assert
Assert.Empty(compilation.Diagnostics);
Assert.True(calledSymbol.IsGenericInstance);
Assert.Same(expectedSymbol, calledSymbol.GenericDefinition);
}
}
102 changes: 0 additions & 102 deletions src/Draco.Compiler/Internal/Solver/ConstraintSolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,106 +150,4 @@ public TypeVariable AllocateTypeVariable(bool track = true)
if (track) this.typeVariables.Add(typeVar);
return typeVar;
}

/// <summary>
/// Unwraps the potential type-variable until it is a non-type-variable type.
/// </summary>
/// <param name="type">The type to unwrap.</param>
/// <returns>The task that completes when <paramref name="type"/> is subsituted as a non-type-variable.</returns>
public static async SolverTask<TypeSymbol> Substituted(TypeSymbol type)
{
while (type is TypeVariable tv) type = await tv.Substituted;
return type;
}

/// <summary>
/// Unifies two types, asserting their success.
/// </summary>
/// <param name="first">The first type to unify.</param>
/// <param name="second">The second type to unify.</param>
public static void UnifyAsserted(TypeSymbol first, TypeSymbol second)
{
if (Unify(first, second)) return;
throw new System.InvalidOperationException($"could not unify {first} and {second}");
}

/// <summary>
/// Attempts to unify two types.
/// </summary>
/// <param name="first">The first type to unify.</param>
/// <param name="second">The second type to unify.</param>
/// <returns>True, if unification was successful, false otherwise.</returns>
public static bool Unify(TypeSymbol first, TypeSymbol second)
{
first = first.Substitution;
second = second.Substitution;

// NOTE: Referential equality is OK here, we don't need to use SymbolEqualityComparer, this is unification
if (ReferenceEquals(first, second)) return true;

switch (first, second)
{
// Type variable substitution takes priority
// so it can unify with never type and error type to stop type errors from cascading
case (TypeVariable v1, TypeVariable v2):
{
// Check for circularity
// NOTE: Referential equality is OK here, we are checking for CIRCULARITY
// which is referential check
if (ReferenceEquals(v1, v2)) return true;
v1.Substitute(v2);
return true;
}
case (TypeVariable v, TypeSymbol other):
{
v.Substitute(other);
return true;
}
case (TypeSymbol other, TypeVariable v):
{
v.Substitute(other);
return true;
}

// Never type is never reached, unifies with everything
case (NeverTypeSymbol, _):
case (_, NeverTypeSymbol):
// Error type unifies with everything to avoid cascading type errors
case (ErrorTypeSymbol, _):
case (_, ErrorTypeSymbol):
return true;

case (ArrayTypeSymbol a1, ArrayTypeSymbol a2) when a1.IsGenericDefinition && a2.IsGenericDefinition:
return a1.Rank == a2.Rank;

// NOTE: Primitives are filtered out already, along with metadata types

case (FunctionTypeSymbol f1, FunctionTypeSymbol f2):
{
if (f1.Parameters.Length != f2.Parameters.Length) return false;
for (var i = 0; i < f1.Parameters.Length; ++i)
{
if (!Unify(f1.Parameters[i].Type, f2.Parameters[i].Type)) return false;
}
return Unify(f1.ReturnType, f2.ReturnType);
}

case (_, _) when first.IsGenericInstance && second.IsGenericInstance:
{
// NOTE: Generic instances might not obey referential equality
Debug.Assert(first.GenericDefinition is not null);
Debug.Assert(second.GenericDefinition is not null);
if (first.GenericArguments.Length != second.GenericArguments.Length) return false;
if (!Unify(first.GenericDefinition, second.GenericDefinition)) return false;
for (var i = 0; i < first.GenericArguments.Length; ++i)
{
if (!Unify(first.GenericArguments[i], second.GenericArguments[i])) return false;
}
return true;
}

default:
return false;
}
}
}
Loading

0 comments on commit 8dc2876

Please sign in to comment.