Просмотр исходного кода

Add default auth policy to generated openapi spec (#11181)

Cody Robibero 1 год назад
Родитель
Сommit
d9e35a969f
1 измененных файлов с 73 добавлено и 59 удалено
  1. 73 59
      Jellyfin.Server/Filters/SecurityRequirementsOperationFilter.cs

+ 73 - 59
Jellyfin.Server/Filters/SecurityRequirementsOperationFilter.cs

@@ -1,91 +1,105 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Jellyfin.Api.Auth.DefaultAuthorizationPolicy;
 using Jellyfin.Api.Constants;
+using Jellyfin.Extensions;
 using Microsoft.AspNetCore.Authorization;
 using Microsoft.OpenApi.Models;
 using Swashbuckle.AspNetCore.SwaggerGen;
 
-namespace Jellyfin.Server.Filters
+namespace Jellyfin.Server.Filters;
+
+/// <summary>
+/// Security requirement operation filter.
+/// </summary>
+public class SecurityRequirementsOperationFilter : IOperationFilter
 {
+    private const string DefaultAuthPolicy = "DefaultAuthorization";
+    private static readonly Type _attributeType = typeof(AuthorizeAttribute);
+
+    private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider;
+
     /// <summary>
-    /// Security requirement operation filter.
+    /// Initializes a new instance of the <see cref="SecurityRequirementsOperationFilter"/> class.
     /// </summary>
-    public class SecurityRequirementsOperationFilter : IOperationFilter
+    /// <param name="authorizationPolicyProvider">The authorization policy provider.</param>
+    public SecurityRequirementsOperationFilter(IAuthorizationPolicyProvider authorizationPolicyProvider)
     {
-        /// <inheritdoc />
-        public void Apply(OpenApiOperation operation, OperationFilterContext context)
-        {
-            var requiredScopes = new List<string>();
+        _authorizationPolicyProvider = authorizationPolicyProvider;
+    }
 
-            var requiresAuth = false;
-            // Add all method scopes.
-            foreach (var attribute in context.MethodInfo.GetCustomAttributes(true))
-            {
-                if (attribute is not AuthorizeAttribute authorizeAttribute)
-                {
-                    continue;
-                }
+    /// <inheritdoc />
+    public void Apply(OpenApiOperation operation, OperationFilterContext context)
+    {
+        var requiredScopes = new List<string>();
 
-                requiresAuth = true;
-                if (authorizeAttribute.Policy is not null
-                    && !requiredScopes.Contains(authorizeAttribute.Policy, StringComparer.Ordinal))
-                {
-                    requiredScopes.Add(authorizeAttribute.Policy);
-                }
+        var requiresAuth = false;
+        // Add all method scopes.
+        foreach (var authorizeAttribute in context.MethodInfo.GetCustomAttributes(_attributeType, true).Cast<AuthorizeAttribute>())
+        {
+            requiresAuth = true;
+            var policy = authorizeAttribute.Policy ?? DefaultAuthPolicy;
+            if (!requiredScopes.Contains(policy, StringComparer.Ordinal))
+            {
+                requiredScopes.Add(policy);
             }
+        }
 
-            // Add controller scopes if any.
-            var controllerAttributes = context.MethodInfo.DeclaringType?.GetCustomAttributes(true);
-            if (controllerAttributes is not null)
+        // Add controller scopes if any.
+        var controllerAttributes = context.MethodInfo.DeclaringType?.GetCustomAttributes(_attributeType, true).Cast<AuthorizeAttribute>();
+        if (controllerAttributes is not null)
+        {
+            foreach (var authorizeAttribute in controllerAttributes)
             {
-                foreach (var attribute in controllerAttributes)
+                requiresAuth = true;
+                var policy = authorizeAttribute.Policy ?? DefaultAuthPolicy;
+                if (!requiredScopes.Contains(policy, StringComparer.Ordinal))
                 {
-                    if (attribute is not AuthorizeAttribute authorizeAttribute)
-                    {
-                        continue;
-                    }
-
-                    requiresAuth = true;
-                    if (authorizeAttribute.Policy is not null
-                        && !requiredScopes.Contains(authorizeAttribute.Policy, StringComparer.Ordinal))
-                    {
-                        requiredScopes.Add(authorizeAttribute.Policy);
-                    }
+                    requiredScopes.Add(policy);
                 }
             }
+        }
 
-            if (!requiresAuth)
-            {
-                return;
-            }
+        if (!requiresAuth)
+        {
+            return;
+        }
 
-            if (!operation.Responses.ContainsKey("401"))
-            {
-                operation.Responses.Add("401", new OpenApiResponse { Description = "Unauthorized" });
-            }
+        if (!operation.Responses.ContainsKey("401"))
+        {
+            operation.Responses.Add("401", new OpenApiResponse { Description = "Unauthorized" });
+        }
 
-            if (!operation.Responses.ContainsKey("403"))
-            {
-                operation.Responses.Add("403", new OpenApiResponse { Description = "Forbidden" });
-            }
+        if (!operation.Responses.ContainsKey("403"))
+        {
+            operation.Responses.Add("403", new OpenApiResponse { Description = "Forbidden" });
+        }
 
-            var scheme = new OpenApiSecurityScheme
+        var scheme = new OpenApiSecurityScheme
+        {
+            Reference = new OpenApiReference
             {
-                Reference = new OpenApiReference
-                {
-                    Type = ReferenceType.SecurityScheme,
-                    Id = AuthenticationSchemes.CustomAuthentication
-                }
-            };
+                Type = ReferenceType.SecurityScheme,
+                Id = AuthenticationSchemes.CustomAuthentication
+            },
+        };
 
-            operation.Security = new List<OpenApiSecurityRequirement>
+        // Add DefaultAuthorization scope to any endpoint that has a policy with a requirement that is a subset of DefaultAuthorization.
+        if (!requiredScopes.Contains(DefaultAuthPolicy.AsSpan(), StringComparison.Ordinal))
+        {
+            foreach (var scope in requiredScopes)
             {
-                new OpenApiSecurityRequirement
+                var authorizationPolicy = _authorizationPolicyProvider.GetPolicyAsync(scope).GetAwaiter().GetResult();
+                if (authorizationPolicy is not null
+                    && authorizationPolicy.Requirements.Any(r => r is DefaultAuthorizationRequirement))
                 {
-                    [scheme] = requiredScopes
+                    requiredScopes.Add(DefaultAuthPolicy);
+                    break;
                 }
-            };
+            }
         }
+
+        operation.Security = [new OpenApiSecurityRequirement { [scheme] = requiredScopes }];
     }
 }