diff --git a/PropelAuth/PropelAuthExtensions.cs b/PropelAuth/PropelAuthExtensions.cs index 2b9df80..1850fb4 100644 --- a/PropelAuth/PropelAuthExtensions.cs +++ b/PropelAuth/PropelAuthExtensions.cs @@ -77,11 +77,17 @@ private static void ConfigureAuthentication(IServiceCollection services, PropelA { var authBuilder = services.AddAuthentication(authOptions => { - if (options.OAuthOptions != null) + if (options.OAuthOptions is {AllowBearerTokenAuth: true}) + { + authOptions.DefaultAuthenticateScheme = "OAuthOrBearer"; + authOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; + authOptions.DefaultChallengeScheme = "OAuthOrBearer"; + } + else if (options.OAuthOptions != null) { authOptions.DefaultAuthenticateScheme = CookieAuthenticationDefaults.AuthenticationScheme; authOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; - authOptions.DefaultChallengeScheme = "PropelAuth"; + authOptions.DefaultChallengeScheme = "OAuth"; } else { @@ -90,6 +96,23 @@ private static void ConfigureAuthentication(IServiceCollection services, PropelA } }); + if (options.OAuthOptions is {AllowBearerTokenAuth: true}) + { + authBuilder.AddPolicyScheme("OAuthOrBearer", "OAuth or Bearer", policyOptions => + { + policyOptions.ForwardDefaultSelector = context => + { + if (context.Request.Headers.ContainsKey("Authorization")) + { + return JwtBearerDefaults.AuthenticationScheme; + } + + return "OAuth"; + }; + }); + } + + if (options.OAuthOptions == null || options.OAuthOptions.AllowBearerTokenAuth == true) { authBuilder.AddJwtBearer(jwtOptions => @@ -104,7 +127,8 @@ private static void ConfigureAuthentication(IServiceCollection services, PropelA }; }); } - else + + if (options.OAuthOptions != null) { authBuilder .AddCookie(cookieOptions => @@ -114,7 +138,7 @@ private static void ConfigureAuthentication(IServiceCollection services, PropelA cookieOptions.Cookie.SecurePolicy = CookieSecurePolicy.Always; cookieOptions.SlidingExpiration = true; }) - .AddOAuth("PropelAuth", configOptions => + .AddOAuth("OAuth", configOptions => { configOptions.AuthorizationEndpoint = $"{options.AuthUrl}/propelauth/oauth/authorize"; configOptions.TokenEndpoint = $"{options.AuthUrl}/propelauth/oauth/token"; @@ -134,6 +158,7 @@ private static void ConfigureAuthentication(IServiceCollection services, PropelA { context.Identity?.AddClaim(claim); } + return Task.CompletedTask; } };