diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 31ae036..0cdb7f6 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -14,6 +14,10 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator { private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; + private static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); + private static readonly string[] AspNetCoreHttpNamespaceParts = new[] { "Microsoft", "AspNetCore", "Http" }; + private static readonly string[] AspNetCoreAuthorizationNamespaceParts = new[] { "Microsoft", "AspNetCore", "Authorization" }; + private static readonly string[] AspNetCoreRoutingNamespaceParts = new[] { "Microsoft", "AspNetCore", "Routing" }; private static readonly ImmutableArray HttpAttributeDefinitions = [ @@ -47,7 +51,7 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; private const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; - private const string AllowAnonymousAttributeFullyQualifiedName = "Microsoft.AspNetCore.Authorization.AllowAnonymousAttribute"; + private const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; private const string AcceptsAttributeName = "AcceptsAttribute"; private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; @@ -716,96 +720,108 @@ ref bool hasRequireAuthorizationAttribute if (attributeClass is null) continue; - var fullyQualifiedName = attributeClass.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - - if (IsGeneratedAttribute(fullyQualifiedName, AcceptsAttributeName)) + if (IsGeneratedAttribute(attributeClass, AcceptsAttributeName)) { TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); continue; } - if (IsGeneratedAttribute(fullyQualifiedName, ProducesResponseAttributeName)) + if (IsGeneratedAttribute(attributeClass, ProducesResponseAttributeName)) { TryAddProducesMetadata(attribute, attributeClass, ref produces); continue; } - switch (fullyQualifiedName) + if (IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) { - case "global::Microsoft.AspNetCore.Http.TagsAttribute": - if (attribute.ConstructorArguments.Length > 0) - { - var arg = attribute.ConstructorArguments[0]; - if (arg.Values.Length > 0) - { - var values = arg.Values - .Select(v => v.Value as string) - .Where(s => !string.IsNullOrWhiteSpace(s)) - .Select(s => s!.Trim()); - - MergeInto(ref tags, values); - } - } - break; - case $"global::{RequireAuthorizationAttributeFullyQualifiedName}": - requireAuthorization = true; - hasRequireAuthorizationAttribute = true; - if (attribute.ConstructorArguments.Length == 1) + if (attribute.ConstructorArguments.Length > 0) + { + var arg = attribute.ConstructorArguments[0]; + if (arg.Values.Length > 0) { - var arg = attribute.ConstructorArguments[0]; - if (arg.Values.Length > 0) - { - var values = arg.Values - .Select(v => v.Value as string) - .Where(s => !string.IsNullOrWhiteSpace(s)) - .Select(s => s!.Trim()); - - MergeInto(ref authorizationPolicies, values); - } + var values = arg.Values + .Select(v => v.Value as string) + .Where(s => !string.IsNullOrWhiteSpace(s)) + .Select(s => s!.Trim()); + + MergeInto(ref tags, values); } - break; - case $"global::{DisableAntiforgeryAttributeFullyQualifiedName}": - disableAntiforgery = true; - break; - case $"global::{AllowAnonymousAttributeFullyQualifiedName}": - allowAnonymous = true; - hasAllowAnonymousAttribute = true; - break; - case "global::Microsoft.AspNetCore.Routing.ExcludeFromDescriptionAttribute": - excludeFromDescription = true; - break; - case $"global::{ProducesProblemAttributeFullyQualifiedName}": - { - var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode - ? producesProblemStatusCode - : 500; - var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 - ? GetStringArrayValues(attribute.ConstructorArguments[2]) - : null; - - var producesProblemList = producesProblem ??= []; - producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); - break; } - case $"global::{ProducesValidationProblemAttributeFullyQualifiedName}": + + continue; + } + + if (IsGeneratedAttribute(attributeClass, RequireAuthorizationAttributeName)) + { + requireAuthorization = true; + hasRequireAuthorizationAttribute = true; + if (attribute.ConstructorArguments.Length == 1) { - var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode - ? producesValidationProblemStatusCode - : 400; - var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 - ? GetStringArrayValues(attribute.ConstructorArguments[2]) - : null; - - var producesValidationProblemList = producesValidationProblem ??= []; - producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); - break; + var arg = attribute.ConstructorArguments[0]; + if (arg.Values.Length > 0) + { + var values = arg.Values + .Select(v => v.Value as string) + .Where(s => !string.IsNullOrWhiteSpace(s)) + .Select(s => s!.Trim()); + + MergeInto(ref authorizationPolicies, values); + } } + + continue; + } + + if (IsGeneratedAttribute(attributeClass, DisableAntiforgeryAttributeName)) + { + disableAntiforgery = true; + continue; + } + + if (IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) + { + allowAnonymous = true; + hasAllowAnonymousAttribute = true; + continue; + } + + if (IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) + { + excludeFromDescription = true; + continue; + } + + if (IsGeneratedAttribute(attributeClass, ProducesProblemAttributeName)) + { + var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode + ? producesProblemStatusCode + : 500; + var contentType = attribute.ConstructorArguments.Length > 1 + ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) + : null; + var additionalContentTypes = attribute.ConstructorArguments.Length > 2 + ? GetStringArrayValues(attribute.ConstructorArguments[2]) + : null; + + var producesProblemList = producesProblem ??= []; + producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); + continue; + } + + if (IsGeneratedAttribute(attributeClass, ProducesValidationProblemAttributeName)) + { + var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode + ? producesValidationProblemStatusCode + : 400; + var contentType = attribute.ConstructorArguments.Length > 1 + ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) + : null; + var additionalContentTypes = attribute.ConstructorArguments.Length > 2 + ? GetStringArrayValues(attribute.ConstructorArguments[2]) + : null; + + var producesValidationProblemList = producesValidationProblem ??= []; + producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); } } } @@ -846,10 +862,29 @@ private static string NormalizeRequiredContentType(string? contentType, string d return builder.Count > 0 ? builder.ToEquatableImmutable() : null; } - private static bool IsGeneratedAttribute(string fullyQualifiedName, string attributeName) + private static bool IsGeneratedAttribute(INamedTypeSymbol attributeClass, string attributeName) { - var prefix = $"global::{AttributesNamespace}.{attributeName}"; - return fullyQualifiedName.StartsWith(prefix, StringComparison.Ordinal); + var definition = attributeClass.OriginalDefinition; + return definition.Name == attributeName && IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts); + } + + private static bool IsAttribute(INamedTypeSymbol attributeClass, string attributeName, string[] namespaceParts) + { + var definition = attributeClass.OriginalDefinition; + return definition.Name == attributeName && IsInNamespace(definition.ContainingNamespace, namespaceParts); + } + + private static bool IsInNamespace(INamespaceSymbol? namespaceSymbol, string[] namespaceParts) + { + for (var i = namespaceParts.Length - 1; i >= 0; i--) + { + if (namespaceSymbol is null || namespaceSymbol.Name != namespaceParts[i]) + return false; + + namespaceSymbol = namespaceSymbol.ContainingNamespace; + } + + return namespaceSymbol is null || namespaceSymbol.IsGlobalNamespace; } private static void TryAddAcceptsMetadata(