From bf98ede74c88edceff720fa101afd19009a3733e Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Thu, 8 Sep 2022 22:50:33 +0200 Subject: [PATCH 1/6] Don't propagate symbols in RelayCommandGenerator --- .../Extensions/ISymbolExtensions.cs | 27 +++++++++++++ .../Input/RelayCommandGenerator.cs | 39 ++++++++----------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs index 269472de..8aa4a57f 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis; namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions; @@ -64,6 +65,32 @@ public static bool HasAttributeWithFullyQualifiedName(this ISymbol symbol, strin return false; } + /// + /// Tries to get an attribute with the specified full name. + /// + /// The input instance to check. + /// The attribute name to look for. + /// The resulting attribute, if it was found. + /// Whether or not has an attribute with the specified name. + public static bool TryGetAttributeWithFullyQualifiedName(this ISymbol symbol, string name, [NotNullWhen(true)] out AttributeData? attributeData) + { + ImmutableArray attributes = symbol.GetAttributes(); + + foreach (AttributeData attribute in attributes) + { + if (attribute.AttributeClass?.HasFullyQualifiedName(name) == true) + { + attributeData = attribute; + + return true; + } + } + + attributeData = null; + + return false; + } + /// /// Calculates the effective accessibility for a given symbol. /// diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs index 06e2c7bc..8e0660f7 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs @@ -23,40 +23,33 @@ public sealed partial class RelayCommandGenerator : IIncrementalGenerator /// public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all method declarations with at least one attribute - IncrementalValuesProvider methodSymbols = + // Gather info for all annotated command methods (starting from method declarations with at least one attribute) + IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> commandInfoWithErrors = context.SyntaxProvider .CreateSyntaxProvider( static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 }, - static (context, _) => + static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) { return default; } - return (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!; - }) - .Where(static item => item is not null)!; + IMethodSymbol methodSymbol = (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, token)!; - // Filter the methods using [RelayCommand] - IncrementalValuesProvider<(IMethodSymbol Symbol, AttributeData Attribute)> methodSymbolsWithAttributeData = - methodSymbols - .Select(static (item, _) => ( - item, - Attribute: item.GetAttributes().FirstOrDefault(a => a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.Input.RelayCommandAttribute") == true))) - .Where(static item => item.Attribute is not null)!; + // Filter the methods using [RelayCommand] + if (!methodSymbol.TryGetAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.Input.RelayCommandAttribute", out AttributeData? attribute)) + { + return default; + } - // Gather info for all annotated command methods - IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> commandInfoWithErrors = - methodSymbolsWithAttributeData - .Select(static (item, _) => - { - HierarchyInfo hierarchy = HierarchyInfo.From(item.Symbol.ContainingType); - CommandInfo? commandInfo = Execute.GetInfo(item.Symbol, item.Attribute, out ImmutableArray diagnostics); + // Produce the incremental models + HierarchyInfo hierarchy = HierarchyInfo.From(methodSymbol.ContainingType); + CommandInfo? commandInfo = Execute.GetInfo(methodSymbol, attribute, out ImmutableArray diagnostics); - return (hierarchy, new Result(commandInfo, diagnostics)); - }); + return (Hierarchy: hierarchy, new Result(commandInfo, diagnostics)); + }) + .Where(static item => item.Hierarchy is not null); // Output the diagnostics context.ReportDiagnostics(commandInfoWithErrors.Select(static (item, _) => item.Info.Errors)); @@ -66,7 +59,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) commandInfoWithErrors .Where(static item => item.Info.Value is not null) .Select(static (item, _) => (item.Hierarchy, item.Info.Value!)) - .WithComparers(HierarchyInfo.Comparer.Default, CommandInfo.Comparer.Default); + .WithComparers(HierarchyInfo.Comparer.Default, CommandInfo.Comparer.Default); // Generate the commands context.RegisterSourceOutput(commandInfo, static (context, item) => From 73ebfb526928d7ab14a38e5a8fe661e057637282 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Thu, 8 Sep 2022 22:57:48 +0200 Subject: [PATCH 2/6] Don't propagate symbols in IMessengerRegisterAllGenerator --- .../IMessengerRegisterAllGenerator.cs | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs index 66508983..eff1fb58 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs @@ -22,37 +22,46 @@ public sealed partial class IMessengerRegisterAllGenerator : IIncrementalGenerat /// public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all class declarations. This pipeline step also needs to filter out duplicate recipient - // definitions (it might happen if a recipient has partial declarations). To do this, all pairs - // of class declarations and associated symbols are gathered, and then only the pair where the - // class declaration is the first syntax reference for the associated symbol is kept. - // Just like with the ObservableValidator generator, we also intentionally skip abstract types. - IncrementalValuesProvider typeSymbols = + // Get the recipient info for all target types + IncrementalValuesProvider recipientInfo = context.SyntaxProvider .CreateSyntaxProvider( static (node, _) => node is ClassDeclarationSyntax, - static (context, _) => + static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) { return default; } - return (context.Node, Symbol: (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!); - }) - .Where(static item => item.Symbol is { IsAbstract: false, IsGenericType: false } && item.Node.IsFirstSyntaxDeclarationForSymbol(item.Symbol)) - .Select(static (item, _) => item.Symbol); + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, token)!; - // Get the target IRecipient interfaces and filter out other types - IncrementalValuesProvider<(INamedTypeSymbol Type, ImmutableArray Interfaces)> typeAndInterfaceSymbols = - typeSymbols - .Select(static (item, _) => (item, Interfaces: Execute.GetInterfaces(item))) - .Where(static item => !item.Interfaces.IsEmpty); + // The type must be a non-abstract, non-generic type (just like with the ObservableValidator generator) + if (typeSymbol is not { IsAbstract: false, IsGenericType: false }) + { + return default; + } - // Get the recipient info for all target types - IncrementalValuesProvider recipientInfo = - typeAndInterfaceSymbols - .Select(static (item, _) => Execute.GetInfo(item.Type, item.Interfaces)) + // This pipeline step also needs to filter out duplicate recipient definitions (it might happen if a + // recipient has partial declarations). To do this, all pairs of class declarations and associated + // symbols are gathered, and then only the pair where the class declaration is the first syntax + // reference for the associated symbol is kept. + if (!context.Node.IsFirstSyntaxDeclarationForSymbol(typeSymbol)) + { + return default; + } + + ImmutableArray interfaceSymbols = Execute.GetInterfaces(typeSymbol); + + // Check that the type implements at least one IRecipient interface + if (interfaceSymbols.IsEmpty) + { + return default; + } + + return Execute.GetInfo(typeSymbol, interfaceSymbols); + }) + .Where(static item => item is not null)! .WithComparer(RecipientInfo.Comparer.Default); // Check whether the header file is needed @@ -68,7 +77,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Gather the conditional flag and attribute availability IncrementalValueProvider<(bool IsHeaderFileNeeded, bool IsDynamicallyAccessedMembersAttributeAvailable)> headerFileInfo = - isHeaderFileNeeded.Combine(isDynamicallyAccessedMembersAttributeAvailable); + isHeaderFileNeeded + .Combine(isDynamicallyAccessedMembersAttributeAvailable); // Generate the header file with the attributes context.RegisterConditionalImplementationSourceOutput(headerFileInfo, static (context, item) => From f1f3f9c6c60b37670e875a0bd49fdbee8a5ea7a7 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sat, 10 Sep 2022 03:27:09 +0200 Subject: [PATCH 3/6] Don't propagate symbols in ObservablePropertyGenerator --- .../ObservablePropertyGenerator.cs | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs index 901f2a71..c9926444 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs @@ -24,38 +24,33 @@ public sealed partial class ObservablePropertyGenerator : IIncrementalGenerator /// public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all field declarations with at least one attribute - IncrementalValuesProvider fieldSymbols = + // Gather info for all annotated fields + IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> propertyInfoWithErrors = context.SyntaxProvider .CreateSyntaxProvider( - static (node, _) => node is FieldDeclarationSyntax { Parent: ClassDeclarationSyntax or RecordDeclarationSyntax, AttributeLists.Count: > 0 }, - static (context, _) => + static (node, _) => node is VariableDeclaratorSyntax { Parent: VariableDeclarationSyntax { Parent: FieldDeclarationSyntax { Parent: ClassDeclarationSyntax or RecordDeclarationSyntax, AttributeLists.Count: > 0 } } }, + static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) { return default; } - return ((FieldDeclarationSyntax)context.Node).Declaration.Variables.Select(v => (IFieldSymbol)context.SemanticModel.GetDeclaredSymbol(v)!); - }) - .Where(static items => items is not null) - .SelectMany(static (item, _) => item!)!; + IFieldSymbol fieldSymbol = (IFieldSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, token)!; - // Filter the fields using [ObservableProperty] - IncrementalValuesProvider fieldSymbolsWithAttribute = - fieldSymbols - .Where(static item => item.HasAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservablePropertyAttribute")); + // Filter the fields using [ObservableProperty] + if (!fieldSymbol.HasAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservablePropertyAttribute")) + { + return default; + } - // Gather info for all annotated fields - IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> propertyInfoWithErrors = - fieldSymbolsWithAttribute - .Select(static (item, _) => - { - HierarchyInfo hierarchy = HierarchyInfo.From(item.ContainingType); - PropertyInfo? propertyInfo = Execute.TryGetInfo(item, out ImmutableArray diagnostics); + // Produce the incremental models + HierarchyInfo hierarchy = HierarchyInfo.From(fieldSymbol.ContainingType); + PropertyInfo? propertyInfo = Execute.TryGetInfo(fieldSymbol, out ImmutableArray diagnostics); - return (hierarchy, new Result(propertyInfo, diagnostics)); - }); + return (Hierarchy: hierarchy, new Result(propertyInfo, diagnostics)); + }) + .Where(static item => item.Hierarchy is not null); // Output the diagnostics context.ReportDiagnostics(propertyInfoWithErrors.Select(static (item, _) => item.Info.Errors)); From 5f90862efca37651928f7e23fd6beb54c6178f08 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sat, 10 Sep 2022 03:31:39 +0200 Subject: [PATCH 4/6] Don't propagate symbols in ObservableValidatorValidateAllPropertiesGenerator --- ...ValidatorValidateAllPropertiesGenerator.cs | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs index 320db7f0..3fdbf8aa 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs @@ -21,31 +21,43 @@ public sealed partial class ObservableValidatorValidateAllPropertiesGenerator : /// public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all class declarations. We intentionally skip generating code for abstract types, as that would never be used. - // The methods that are generated by this generator are retrieved through reflection using the type of the invoking - // instance as discriminator, which means a type that is abstract could never be used (since it couldn't be instantiated). - IncrementalValuesProvider typeSymbols = + // Get the types that inherit from ObservableValidator and gather their info + IncrementalValuesProvider validationInfo = context.SyntaxProvider .CreateSyntaxProvider( static (node, _) => node is ClassDeclarationSyntax, - static (context, _) => + static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) { return default; } - return (context.Node, Symbol: (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!); - }) - .Where(static item => item.Symbol is { IsAbstract: false, IsGenericType: false } && item.Node.IsFirstSyntaxDeclarationForSymbol(item.Symbol)) - .Select(static (item, _) => item.Symbol); + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, token)!; - // Get the types that inherit from ObservableValidator and gather their info - IncrementalValuesProvider validationInfo = - typeSymbols - .Where(Execute.IsObservableValidator) - .Select(static (item, _) => Execute.GetInfo(item)) - .WithComparer(ValidationInfo.Comparer.Default); + // Skip generating code for abstract types, as that would never be used. The methods that are generated by + // this generator are retrieved through reflection using the type of the invoking instance as discriminator, + // which means a type that is abstract could never be used (since it couldn't be instantiated). + if (typeSymbol is not { IsAbstract: false, IsGenericType: false }) + { + return default; + } + + // Just like in IMessengerRegisterAllGenerator, only select the first declaration for this type symbol + if (!context.Node.IsFirstSyntaxDeclarationForSymbol(typeSymbol)) + { + return default; + } + + // Only select types inheriting from ObservableValidator + if (!Execute.IsObservableValidator(typeSymbol)) + { + return default; + } + + return Execute.GetInfo(typeSymbol); + }) + .Where(static item => item is not null)!; // Check whether the header file is needed IncrementalValueProvider isHeaderFileNeeded = From 27cd6b47a65260746875e2a1269d2a9505b149c7 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sat, 10 Sep 2022 13:35:14 +0200 Subject: [PATCH 5/6] Don't propagate symbols in TransitiveMembersGenerator --- .../INotifyPropertyChangedGenerator.cs | 34 +++------ .../ObservableObjectGenerator.cs | 25 ++----- .../ObservableRecipientGenerator.cs | 72 +++++++----------- .../TransitiveMembersGenerator.cs | 75 ++++++++----------- 4 files changed, 73 insertions(+), 133 deletions(-) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs index 2831b6a1..37882bc0 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs @@ -28,33 +28,18 @@ public INotifyPropertyChangedGenerator() } /// - protected override IncrementalValuesProvider<(INamedTypeSymbol Symbol, INotifyPropertyChangedInfo Info)> GetInfo( - IncrementalGeneratorInitializationContext context, - IncrementalValuesProvider<(INamedTypeSymbol Symbol, AttributeData AttributeData)> source) - { - static INotifyPropertyChangedInfo GetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData) - { - bool includeAdditionalHelperMethods = attributeData.GetNamedArgument("IncludeAdditionalHelperMethods", true); - - return new(includeAdditionalHelperMethods); - } - - return source.Select(static (item, _) => (item.Symbol, GetInfo(item.Symbol, item.AttributeData))); - } - - /// - protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, INotifyPropertyChangedInfo info, out ImmutableArray diagnostics) + protected override INotifyPropertyChangedInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) { ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + INotifyPropertyChangedInfo? info = null; + // Check if the type already implements INotifyPropertyChanged if (typeSymbol.AllInterfaces.Any(i => i.HasFullyQualifiedName("global::System.ComponentModel.INotifyPropertyChanged"))) { builder.Add(DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } // Check if the type uses [INotifyPropertyChanged] or [ObservableObject] already (in the type hierarchy too) @@ -63,14 +48,17 @@ protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, INotifyP { builder.Add(InvalidAttributeCombinationForINotifyPropertyChangedAttributeError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } + bool includeAdditionalHelperMethods = attributeData.GetNamedArgument("IncludeAdditionalHelperMethods", true); + + info = new INotifyPropertyChangedInfo(includeAdditionalHelperMethods); + + End: diagnostics = builder.ToImmutable(); - return true; + return info; } /// diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs index 114cb937..c101ad09 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs @@ -27,15 +27,7 @@ public ObservableObjectGenerator() } /// - protected override IncrementalValuesProvider<(INamedTypeSymbol Symbol, object? Info)> GetInfo( - IncrementalGeneratorInitializationContext context, - IncrementalValuesProvider<(INamedTypeSymbol Symbol, AttributeData AttributeData)> source) - { - return source.Select(static (item, _) => (item.Symbol, (object?)null)); - } - - /// - protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, object? info, out ImmutableArray diagnostics) + protected override object? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) { ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); @@ -44,9 +36,7 @@ protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, object? { builder.Add(DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } // ...or INotifyPropertyChanging @@ -54,9 +44,7 @@ protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, object? { builder.Add(DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } // Check if the type uses [INotifyPropertyChanged] or [ObservableObject] already (in the type hierarchy too) @@ -65,14 +53,13 @@ protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, object? { builder.Add(InvalidAttributeCombinationForObservableObjectAttributeError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } + End: diagnostics = builder.ToImmutable(); - return true; + return null; } /// diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs index e16a161d..af217683 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs @@ -30,53 +30,18 @@ public ObservableRecipientGenerator() } /// - protected override IncrementalValuesProvider<(INamedTypeSymbol Symbol, ObservableRecipientInfo Info)> GetInfo( - IncrementalGeneratorInitializationContext context, - IncrementalValuesProvider<(INamedTypeSymbol Symbol, AttributeData AttributeData)> source) - { - static ObservableRecipientInfo GetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, bool isRequiresUnreferencedCodeAttributeAvailable) - { - string typeName = typeSymbol.Name; - bool hasExplicitConstructors = !(typeSymbol.InstanceConstructors.Length == 1 && typeSymbol.InstanceConstructors[0] is { Parameters.IsEmpty: true, IsImplicitlyDeclared: true }); - bool isAbstract = typeSymbol.IsAbstract; - bool isObservableValidator = typeSymbol.InheritsFromFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableValidator"); - bool hasOnActivatedMethod = typeSymbol.GetMembers().Any(m => m is IMethodSymbol { Parameters.IsEmpty: true, Name: "OnActivated" }); - bool hasOnDeactivatedMethod = typeSymbol.GetMembers().Any(m => m is IMethodSymbol { Parameters.IsEmpty: true, Name: "OnDeactivated" }); - - return new( - typeName, - hasExplicitConstructors, - isAbstract, - isObservableValidator, - isRequiresUnreferencedCodeAttributeAvailable, - hasOnActivatedMethod, - hasOnDeactivatedMethod); - } - - // Check whether [RequiresUnreferencedCode] is available - IncrementalValueProvider isRequiresUnreferencedCodeAttributeAvailable = - context.CompilationProvider - .Select(static (item, _) => item.GetTypeByMetadataName("System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute") is { DeclaredAccessibility: Accessibility.Public }); - - return - source - .Combine(isRequiresUnreferencedCodeAttributeAvailable) - .Select(static (item, _) => (item.Left.Symbol, GetInfo(item.Left.Symbol, item.Left.AttributeData, item.Right))); - } - - /// - protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, ObservableRecipientInfo info, out ImmutableArray diagnostics) + protected override ObservableRecipientInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) { ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + ObservableRecipientInfo? info = null; + // Check if the type already inherits from ObservableRecipient if (typeSymbol.InheritsFromFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableRecipient")) { builder.Add(DuplicateObservableRecipientError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } // Check if the type already inherits [ObservableRecipient] @@ -84,9 +49,7 @@ protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, Observab { builder.Add(InvalidAttributeCombinationForObservableRecipientAttributeError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } // In order to use [ObservableRecipient], the target type needs to inherit from ObservableObject, @@ -99,14 +62,31 @@ protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, Observab { builder.Add(MissingBaseObservableObjectFunctionalityError, typeSymbol, typeSymbol); - diagnostics = builder.ToImmutable(); - - return false; + goto End; } + // Gather all necessary info to propagate down the pipeline + string typeName = typeSymbol.Name; + bool hasExplicitConstructors = !(typeSymbol.InstanceConstructors.Length == 1 && typeSymbol.InstanceConstructors[0] is { Parameters.IsEmpty: true, IsImplicitlyDeclared: true }); + bool isAbstract = typeSymbol.IsAbstract; + bool isObservableValidator = typeSymbol.InheritsFromFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableValidator"); + bool isRequiresUnreferencedCodeAttributeAvailable = compilation.HasAccessibleTypeWithMetadataName("System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute"); + bool hasOnActivatedMethod = typeSymbol.GetMembers().Any(m => m is IMethodSymbol { Parameters.IsEmpty: true, Name: "OnActivated" }); + bool hasOnDeactivatedMethod = typeSymbol.GetMembers().Any(m => m is IMethodSymbol { Parameters.IsEmpty: true, Name: "OnDeactivated" }); + + info = new ObservableRecipientInfo( + typeName, + hasExplicitConstructors, + isAbstract, + isObservableValidator, + isRequiresUnreferencedCodeAttributeAvailable, + hasOnActivatedMethod, + hasOnDeactivatedMethod); + + End: diagnostics = builder.ToImmutable(); - return true; + return info; } /// diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs index 8b0a01a0..95fc8967 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs @@ -11,7 +11,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; namespace CommunityToolkit.Mvvm.SourceGenerators; @@ -70,46 +69,40 @@ private protected TransitiveMembersGenerator(string attributeType, IEqualityComp /// public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all class declarations - IncrementalValuesProvider typeSymbols = + // Gather all generation info, and any diagnostics + IncrementalValuesProvider> generationInfoWithErrors = context.SyntaxProvider .CreateSyntaxProvider( static (node, _) => node is ClassDeclarationSyntax { AttributeLists.Count: > 0 }, - static (context, _) => + (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) { return default; } - return (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!; - }) - .Where(static item => item is not null)!; + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, token)!; - // Filter the types with the target attribute - IncrementalValuesProvider<(INamedTypeSymbol Symbol, AttributeData AttributeData)> typeSymbolsWithAttributeData = - typeSymbols - .Select((item, _) => ( - Symbol: item, - Attribute: item.GetAttributes().FirstOrDefault(a => a.AttributeClass?.HasFullyQualifiedName(this.attributeType) == true))) - .Where(static item => item.Attribute is not null)!; + // Filter the types with the target attribute + if (!typeSymbol.TryGetAttributeWithFullyQualifiedName(this.attributeType, out AttributeData? attributeData)) + { + return default; + } - // Transform the input data - IncrementalValuesProvider<(INamedTypeSymbol Symbol, TInfo Info)> typeSymbolsWithInfo = GetInfo(context, typeSymbolsWithAttributeData); + // Gather all generation info, and any diagnostics + TInfo? info = ValidateTargetTypeAndGetInfo(typeSymbol, attributeData, context.SemanticModel.Compilation, out ImmutableArray diagnostics); - // Gather all generation info, and any diagnostics - IncrementalValuesProvider> generationInfoWithErrors = - typeSymbolsWithInfo.Select((item, _) => - { - if (ValidateTargetType(item.Symbol, item.Info, out ImmutableArray diagnostics)) - { - return new Result<(HierarchyInfo, bool, TInfo)>( - (HierarchyInfo.From(item.Symbol), item.Symbol.IsSealed, item.Info), - ImmutableArray.Empty); - } + // If there are any diagnostics, there's no need to compute the hierarchy info at all, just return them + if (diagnostics.Length > 0) + { + return new Result<(HierarchyInfo, bool, TInfo?)>(default, diagnostics); + } + + HierarchyInfo hierarchy = HierarchyInfo.From(typeSymbol); - return new Result<(HierarchyInfo, bool, TInfo)>(default, diagnostics); - }); + return new Result<(HierarchyInfo, bool, TInfo?)>((hierarchy, typeSymbol.IsSealed, info), diagnostics); + }) + .Where(static item => item is not null)!; // Emit the diagnostic, if needed context.ReportDiagnostics(generationInfoWithErrors.Select(static (item, _) => item.Errors)); @@ -118,7 +111,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) IncrementalValuesProvider<(HierarchyInfo Hierarchy, bool IsSealed, TInfo Info)> generationInfo = generationInfoWithErrors .Where(static item => item.Errors.IsEmpty) - .Select(static (item, _) => item.Value) + .Select(static (item, _) => item.Value)! .WithComparers(HierarchyInfo.Comparer.Default, EqualityComparer.Default, this.comparer); // Generate the required members @@ -133,23 +126,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } /// - /// Gathers info from a source input. + /// Validates the target type being processes, gets the info if possible and produces all necessary diagnostics. /// - /// The instance in use. - /// The source input. - /// A transformed instance with the gathered data. - protected abstract IncrementalValuesProvider<(INamedTypeSymbol Symbol, TInfo Info)> GetInfo( - IncrementalGeneratorInitializationContext context, - IncrementalValuesProvider<(INamedTypeSymbol Symbol, AttributeData AttributeData)> source); - - /// - /// Validates a target type being processed. - /// - /// The instance for the target type. - /// The instance with the current processing info. - /// The resulting diagnostics from the processing operation. - /// Whether or not the target type is valid and can be processed normally. - protected abstract bool ValidateTargetType(INamedTypeSymbol typeSymbol, TInfo info, out ImmutableArray diagnostics); + /// The instance currently being processed. + /// The instance for the attribute used over . + /// The compilation that belongs to. + /// The resulting diagnostics, if any. + /// The extracted info for the current type, if possible. + /// If is empty, the returned info will always be ignored and no sources will be produced. + protected abstract TInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics); /// /// Filters the nodes to generate from the input parsed tree. From b6dd39b4d147d639e49aeaa460e9be31a25d2aa6 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sat, 10 Sep 2022 13:43:24 +0200 Subject: [PATCH 6/6] Improve syntactical filtering for messaging/validator generators --- ...ValidatorValidateAllPropertiesGenerator.cs | 2 +- .../TypeDeclarationSyntaxExtensions.cs | 41 +++++++++++++++++++ .../IMessengerRegisterAllGenerator.cs | 2 +- 3 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 CommunityToolkit.Mvvm.SourceGenerators/Extensions/TypeDeclarationSyntaxExtensions.cs diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs index 3fdbf8aa..5417571d 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs @@ -25,7 +25,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) IncrementalValuesProvider validationInfo = context.SyntaxProvider .CreateSyntaxProvider( - static (node, _) => node is ClassDeclarationSyntax, + static (node, _) => node is ClassDeclarationSyntax classDeclaration && classDeclaration.HasOrPotentiallyHasBaseTypes(), static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/TypeDeclarationSyntaxExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/TypeDeclarationSyntaxExtensions.cs new file mode 100644 index 00000000..32148a5d --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/TypeDeclarationSyntaxExtensions.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions; + +/// +/// Extension methods for the type. +/// +internal static class TypeDeclarationSyntaxExtensions +{ + /// + /// Checks whether a given has or could possibly have any base types, using only syntax. + /// + /// The input instance to check. + /// Whether has or could possibly have any base types. + public static bool HasOrPotentiallyHasBaseTypes(this TypeDeclarationSyntax typeDeclaration) + { + // If the base types list is not empty, the type can definitely has implemented interfaces + if (typeDeclaration.BaseList is { Types.Count: > 0 }) + { + return true; + } + + // If the base types list is empty, check if the type is partial. If it is, it means + // that there could be another partial declaration with a non-empty base types list. + foreach (SyntaxToken modifier in typeDeclaration.Modifiers) + { + if (modifier.IsKind(SyntaxKind.PartialKeyword)) + { + return true; + } + } + + return false; + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs index eff1fb58..be66b42c 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs @@ -26,7 +26,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) IncrementalValuesProvider recipientInfo = context.SyntaxProvider .CreateSyntaxProvider( - static (node, _) => node is ClassDeclarationSyntax, + static (node, _) => node is ClassDeclarationSyntax classDeclaration && classDeclaration.HasOrPotentiallyHasBaseTypes(), static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8))