using Gateway.Data;
using Microsoft.IdentityModel.Protocols;
using Microsoft.IdentityModel.Protocols.OpenIdConnect;
using Microsoft.IdentityModel.Tokens;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Text.Json;
namespace Gateway.Security;
/// <summary>
/// Multi-provider authentication middleware.
/// Supports Microsoft Entra ID (workforce tenants and External ID / CIAM) and Google,
/// and is structured so further providers can be added.
/// </summary>
/// <remarks>
/// <list type="bullet">
/// <item><c>/api/auth/session</c>: requires a JWT that validates against one of the configured providers.</item>
/// <item>All other non-anonymous paths: require a valid session token (or the development bypass).</item>
/// </list>
/// </remarks>
public sealed class MultiProviderAuthMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ILogger<MultiProviderAuthMiddleware> _logger;
    private readonly IConfiguration _config;

    // Paths that don't require auth (exact, case-insensitive match).
    private static readonly HashSet<string> _anonymousExact = new(StringComparer.OrdinalIgnoreCase)
    {
        "/",
        "/health"
    };

    // Path prefixes that don't require auth. Matched on a segment boundary (see IsAnonymousPath),
    // so "/api/test" covers "/api/test" and "/api/test/..." but NOT "/api/testimonials".
    private static readonly string[] _anonymousPrefixes = { "/swagger", "/api/test" };

    // OIDC configuration managers, cached per provider + metadata address so that different
    // tenants / CIAM domains never share one cached metadata document (and its signing keys).
    private static readonly Dictionary<string, ConfigurationManager<OpenIdConnectConfiguration>> _oidcManagers = new();
    private static readonly object _oidcLock = new();

    // Cap on cache entries created for CIAM domains taken from a not-yet-validated token issuer.
    // Without it, tokens carrying random *.ciamlogin.com hosts would grow the cache without bound.
    // Configure Auth:Microsoft:CiamDomain to avoid relying on issuer-derived domains at all.
    private const int MaxTokenDerivedOidcManagers = 8;
    private static int _tokenDerivedOidcManagers; // only read/written while holding _oidcLock

    /// <summary>Creates the middleware.</summary>
    /// <param name="next">The next delegate in the pipeline.</param>
    /// <param name="logger">Logger for authentication decisions.</param>
    /// <param name="config">Configuration holding the <c>Auth:*</c> provider settings.</param>
    public MultiProviderAuthMiddleware(RequestDelegate next, ILogger<MultiProviderAuthMiddleware> logger, IConfiguration config)
    {
        _next = next;
        _logger = logger;
        _config = config;
    }

    /// <summary>
    /// Authenticates the request and either forwards it to the next middleware or
    /// short-circuits with a 401 JSON body.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="clientContext">Per-request identity holder populated on success.</param>
    /// <param name="sql">Data access used to validate session tokens.</param>
    public async Task InvokeAsync(HttpContext context, ClientContext clientContext, SqlService sql)
    {
        var pathRaw = context.Request.Path.Value ?? "";
        var path = pathRaw.ToLowerInvariant();
        var corrId = EnsureCorrelationId(context);

        _logger.LogWarning("[Auth] HIT {Method} {Path} | Corr={Corr}", context.Request.Method, pathRaw, corrId);

        // Anonymous paths
        if (IsAnonymousPath(path))
        {
            SetAuthHeaders(context, corrId, "anonymous", null);
            await _next(context);
            return;
        }

        // ---------------------------------------------------------------------
        // SESSION EXCHANGE: Accept JWT from any configured provider
        // ---------------------------------------------------------------------
        if (path.StartsWith("/api/auth/session", StringComparison.OrdinalIgnoreCase))
        {
            var (jwtValid, provider) = await TryMultiProviderJwtAsync(context, clientContext, corrId);
            if (jwtValid)
            {
                SetAuthHeaders(context, corrId, $"jwt({provider})", null);
                _logger.LogWarning("[Auth] Session exchange authorized via {Provider} JWT | Email={Email} | Corr={Corr}",
                    provider, clientContext.Email, corrId);
                await _next(context);
                return;
            }

            SetAuthHeaders(context, corrId, "jwt", "jwt-required");
            _logger.LogWarning("[Auth] Session exchange denied: valid JWT required | Corr={Corr}", corrId);
            context.Response.StatusCode = 401;
            await context.Response.WriteAsJsonAsync(new
            {
                ok = false,
                error = "Valid authentication required from a supported provider",
                correlationId = corrId
            });
            return;
        }

        // ---------------------------------------------------------------------
        // ALL OTHER /api/* PATHS: Require session token (or dev bypass)
        // ---------------------------------------------------------------------
        if (TryDevBypass(context, clientContext, corrId))
        {
            SetAuthHeaders(context, corrId, "dev-bypass", null);
            await _next(context);
            return;
        }

        if (await TrySessionAuthAsync(context, clientContext, sql, corrId))
        {
            SetAuthHeaders(context, corrId, "session", null);
            await _next(context);
            return;
        }

        SetAuthHeaders(context, corrId, "session", "session-required");
        _logger.LogWarning("[Auth] UNAUTHORIZED: valid session required | {Path} | Corr={Corr}", pathRaw, corrId);
        context.Response.StatusCode = 401;
        await context.Response.WriteAsJsonAsync(new
        {
            ok = false,
            error = "Valid session required",
            correlationId = corrId
        });
    }

    /// <summary>
    /// Tries to validate the bearer JWT against each candidate provider: the
    /// <c>X-Auth-Provider</c> hint first, then the provider suggested by the (unvalidated)
    /// issuer, then every configured provider.
    /// </summary>
    /// <returns>(true, providerName) for the first provider that validates the token; otherwise (false, null).</returns>
    private async Task<(bool Success, string? Provider)> TryMultiProviderJwtAsync(
        HttpContext context,
        ClientContext clientContext,
        string corrId)
    {
        if (!context.Request.Headers.TryGetValue("Authorization", out var authHeader))
        {
            _logger.LogWarning("[Auth] No Authorization header | Corr={Corr}", corrId);
            return (false, null);
        }

        var headerValue = authHeader.FirstOrDefault();
        if (string.IsNullOrWhiteSpace(headerValue) || !headerValue.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
        {
            _logger.LogWarning("[Auth] Invalid Authorization header format | Corr={Corr}", corrId);
            return (false, null);
        }

        var token = headerValue["Bearer ".Length..].Trim();
        if (string.IsNullOrWhiteSpace(token))
        {
            _logger.LogWarning("[Auth] Empty bearer token | Corr={Corr}", corrId);
            return (false, null);
        }

        // Check for provider hint from frontend
        var providerHint = context.Request.Headers["X-Auth-Provider"].FirstOrDefault()?.ToLowerInvariant();

        // Read token (unvalidated) to get issuer for auto-detection; malformed tokens yield null
        // instead of throwing out of the middleware as a 500.
        var jwt = TryReadJwt(token, corrId);
        if (jwt != null)
        {
            _logger.LogWarning("[Auth] JWT presented | iss={Iss} aud={Aud} | Corr={Corr}",
                jwt.Issuer, jwt.Audiences.FirstOrDefault(), corrId);
        }

        // Try providers in order (hint first, then auto-detect)
        var providersToTry = new List<string>();
        if (!string.IsNullOrWhiteSpace(providerHint))
            providersToTry.Add(providerHint);

        // Auto-detect based on issuer. This only orders the attempts; the actual issuer
        // check happens during signature/claims validation in each provider method.
        var issuer = jwt?.Issuer ?? string.Empty;
        if (issuer.Contains("login.microsoftonline.com") || issuer.Contains("sts.windows.net") || issuer.Contains("ciamlogin.com"))
            providersToTry.Add("microsoft");
        else if (issuer.Contains("accounts.google.com"))
            providersToTry.Add("google");

        // Fallback: try all configured providers
        if (IsProviderConfigured("microsoft") && !providersToTry.Contains("microsoft"))
            providersToTry.Add("microsoft");
        if (IsProviderConfigured("google") && !providersToTry.Contains("google"))
            providersToTry.Add("google");

        foreach (var provider in providersToTry.Distinct())
        {
            var success = provider switch
            {
                "microsoft" => await TryValidateMicrosoftJwtAsync(token, clientContext, corrId),
                "google" => await TryValidateGoogleJwtAsync(token, clientContext, corrId),
                _ => false
            };

            if (success)
            {
                clientContext.AuthProvider = provider;
                return (true, provider);
            }
        }

        return (false, null);
    }

    /// <summary>
    /// Parses <paramref name="token"/> WITHOUT validating it, for issuer inspection only.
    /// Returns null when the token is not a readable JWT rather than throwing.
    /// </summary>
    private JwtSecurityToken? TryReadJwt(string token, string corrId)
    {
        var handler = new JwtSecurityTokenHandler();
        if (!handler.CanReadToken(token))
            return null;

        try
        {
            return handler.ReadJwtToken(token);
        }
        catch (Exception ex)
        {
            // CanReadToken only checks the shape; the segments can still fail to decode.
            _logger.LogWarning("[Auth] Unreadable JWT: {Msg} | Corr={Corr}", ex.Message, corrId);
            return null;
        }
    }

    /// <summary>
    /// Validates a Microsoft Entra ID token: External ID (CIAM), the client workforce tenant,
    /// or the optional staff tenant (<c>Auth:Microsoft:StaffTenantId</c>).
    /// On success populates <paramref name="clientContext"/> from the token claims.
    /// </summary>
    /// <returns><c>clientContext.IsAuthenticated</c> after a successful validation; false on any failure.</returns>
    private async Task<bool> TryValidateMicrosoftJwtAsync(string token, ClientContext clientContext, string corrId)
    {
        var tenantId = _config["Auth:Microsoft:TenantId"] ?? _config["Auth:EntraId:TenantId"] ?? _config["ENTRA_TENANT_ID"];
        var clientId = _config["Auth:Microsoft:ClientId"] ?? _config["Auth:EntraId:ClientId"] ?? _config["ENTRA_CLIENT_ID"];
        var ciamDomain = _config["Auth:Microsoft:CiamDomain"] ?? _config["Auth:EntraId:CiamDomain"];

        if (string.IsNullOrWhiteSpace(tenantId) || string.IsNullOrWhiteSpace(clientId))
        {
            _logger.LogWarning("[Auth] Microsoft provider not configured | Corr={Corr}", corrId);
            return false;
        }

        try
        {
            // Peek at the (not yet validated) issuer to choose CIAM vs workforce metadata.
            var handler = new JwtSecurityTokenHandler();
            var jwt = handler.ReadJwtToken(token);
            var issuer = jwt.Issuer ?? string.Empty;
            var isCiam = issuer.Contains("ciamlogin.com", StringComparison.OrdinalIgnoreCase);
            var isStaff = false;
            var domainFromToken = false;

            string metadataAddress;
            string[] validIssuers;

            if (isCiam)
            {
                // CIAM (External ID): domain from config, or derived from the issuer.
                var domain = ciamDomain;
                if (string.IsNullOrWhiteSpace(domain))
                {
                    // The issuer is attacker-controlled until validated. Only accept a genuine
                    // https *.ciamlogin.com host; otherwise a forged token could make us fetch
                    // metadata (and trust signing keys) from a server the attacker controls.
                    domain = TryGetCiamHost(issuer);
                    if (domain is null)
                    {
                        _logger.LogWarning("[Auth] CIAM issuer host not trusted (configure Auth:Microsoft:CiamDomain) | iss={Iss} | Corr={Corr}",
                            issuer, corrId);
                        return false;
                    }
                    domainFromToken = true;
                }

                var authority = $"https://{domain}/{tenantId}/v2.0";
                metadataAddress = $"{authority}/.well-known/openid-configuration";
                validIssuers = new[]
                {
                    $"https://{domain}/{tenantId}/v2.0",
                    $"https://{domain}/{tenantId}"
                };

                _logger.LogWarning("[Auth] CIAM token detected | domain={Domain} | Corr={Corr}", domain, corrId);
            }
            else
            {
                // Standard Entra ID: either the client tenant or the staff tenant (Tech, Admin).
                // Detect by comparing issuer against configured Staff tenant ID; the exact
                // issuer is enforced below via ValidIssuers.
                var staffTenantId = _config["Auth:Microsoft:StaffTenantId"];
                var staffClientId = _config["Auth:Microsoft:StaffClientId"];
                isStaff = !string.IsNullOrWhiteSpace(staffTenantId) &&
                          issuer.Contains(staffTenantId, StringComparison.OrdinalIgnoreCase);

                if (isStaff)
                {
                    tenantId = staffTenantId!;
                    clientId = staffClientId ?? clientId;
                    _logger.LogWarning("[Auth] Staff Entra token detected | tenant={Tenant} | Corr={Corr}", tenantId, corrId);
                }

                var authority = $"https://login.microsoftonline.com/{tenantId}/v2.0";
                metadataAddress = $"{authority}/.well-known/openid-configuration";
                validIssuers = new[]
                {
                    $"https://login.microsoftonline.com/{tenantId}/v2.0",
                    $"https://sts.windows.net/{tenantId}/"
                };
            }

            var mgr = GetOrCreateOidcManager(isCiam ? "microsoft-ciam" : "microsoft", metadataAddress, domainFromToken);
            if (mgr is null)
            {
                _logger.LogWarning("[Auth] CIAM metadata cache full; set Auth:Microsoft:CiamDomain | Corr={Corr}", corrId);
                return false;
            }

            var openIdConfig = await mgr.GetConfigurationAsync(CancellationToken.None);

            var validationParams = new TokenValidationParameters
            {
                ValidateIssuer = true,
                ValidIssuers = validIssuers,
                ValidateAudience = true,
                ValidAudiences = new[] { clientId, $"api://{clientId}" },
                ValidateLifetime = true,
                IssuerSigningKeys = openIdConfig.SigningKeys,
                ClockSkew = TimeSpan.FromMinutes(5)
            };

            var principal = handler.ValidateToken(token, validationParams, out _);
            ExtractClaims(principal, clientContext);

            // Only mark the caller as staff once the token has actually been validated
            // against the staff tenant's issuer and keys.
            if (isStaff)
                clientContext.IsStaff = true;

            _logger.LogWarning("[Auth] Microsoft JWT validated ({Mode}) | sub={Sub} email={Email} | Corr={Corr}",
                isCiam ? "CIAM" : "Entra", clientContext.ClientId, clientContext.Email, corrId);

            return clientContext.IsAuthenticated;
        }
        catch (Exception ex)
        {
            _logger.LogWarning("[Auth] Microsoft JWT validation failed: {Msg} | Corr={Corr}", ex.Message, corrId);
            return false;
        }
    }

    /// <summary>
    /// Returns the issuer's host if it is an https <c>*.ciamlogin.com</c> host, otherwise null.
    /// </summary>
    private static string? TryGetCiamHost(string issuer)
    {
        if (!Uri.TryCreate(issuer, UriKind.Absolute, out var issuerUri) || issuerUri.Scheme != Uri.UriSchemeHttps)
            return null;

        return issuerUri.Host.EndsWith(".ciamlogin.com", StringComparison.OrdinalIgnoreCase)
            ? issuerUri.Host
            : null;
    }

    /// <summary>
    /// Validates a Google ID token against Google's published OIDC metadata and the configured
    /// client ID. On success populates <paramref name="clientContext"/> from the token claims.
    /// </summary>
    /// <returns><c>clientContext.IsAuthenticated</c> after a successful validation; false on any failure.</returns>
    private async Task<bool> TryValidateGoogleJwtAsync(string token, ClientContext clientContext, string corrId)
    {
        var clientId = _config["Auth:Google:ClientId"] ?? _config["GOOGLE_CLIENT_ID"];

        if (string.IsNullOrWhiteSpace(clientId))
        {
            _logger.LogWarning("[Auth] Google provider not configured | Corr={Corr}", corrId);
            return false;
        }

        try
        {
            var metadataAddress = "https://accounts.google.com/.well-known/openid-configuration";
            var mgr = GetOrCreateOidcManager("google", metadataAddress)
                ?? throw new InvalidOperationException("OIDC configuration manager unavailable");
            var openIdConfig = await mgr.GetConfigurationAsync(CancellationToken.None);

            var validationParams = new TokenValidationParameters
            {
                ValidateIssuer = true,
                ValidIssuers = new[] { "https://accounts.google.com", "accounts.google.com" },
                ValidateAudience = true,
                ValidAudiences = new[] { clientId },
                ValidateLifetime = true,
                IssuerSigningKeys = openIdConfig.SigningKeys,
                ClockSkew = TimeSpan.FromMinutes(5)
            };

            var handler = new JwtSecurityTokenHandler();
            var principal = handler.ValidateToken(token, validationParams, out _);
            ExtractClaims(principal, clientContext);

            _logger.LogWarning("[Auth] Google JWT validated | sub={Sub} email={Email} | Corr={Corr}",
                clientContext.ClientId, clientContext.Email, corrId);

            return clientContext.IsAuthenticated;
        }
        catch (Exception ex)
        {
            _logger.LogWarning("[Auth] Google JWT validation failed: {Msg} | Corr={Corr}", ex.Message, corrId);
            return false;
        }
    }

    /// <summary>
    /// Copies identity claims (oid/sub, email, name) from a validated principal into
    /// <paramref name="clientContext"/> and clears the dev-bypass flag.
    /// </summary>
    private static void ExtractClaims(ClaimsPrincipal principal, ClientContext clientContext)
    {
        // Always extract oid explicitly — used for activity logging and identity.
        // With inbound claim mapping, oid may appear under the full claim URI instead.
        var oid = principal.FindFirstValue("oid")
               ?? principal.FindFirstValue("http://schemas.microsoft.com/identity/claims/objectidentifier");
        clientContext.EntraOid = oid;

        // ClientId: prefer oid, fall back to sub
        clientContext.ClientId =
            oid ??
            principal.FindFirstValue("sub") ??
            principal.FindFirstValue(ClaimTypes.NameIdentifier);

        clientContext.Email =
            principal.FindFirstValue("email") ??
            principal.FindFirstValue("preferred_username") ??
            principal.FindFirstValue(ClaimTypes.Email);

        clientContext.ClientName =
            principal.FindFirstValue("name") ??
            principal.FindFirstValue(ClaimTypes.Name);

        clientContext.IsDevBypass = false;
    }

    /// <summary>
    /// Validates a session token (from <c>X-Session-Token</c>, else <c>Authorization: Bearer</c>)
    /// via <c>dbo.spClientSession</c> and populates <paramref name="clientContext"/> from the result.
    /// </summary>
    /// <returns><c>clientContext.IsAuthenticated</c> when the procedure reports <c>ok: true</c>; otherwise false.</returns>
    private async Task<bool> TrySessionAuthAsync(HttpContext context, ClientContext clientContext, SqlService sql, string corrId)
    {
        string? token = null;

        // Check X-Session-Token header first
        if (context.Request.Headers.TryGetValue("X-Session-Token", out var sessionHeader))
            token = sessionHeader.FirstOrDefault();

        // Fall back to Authorization: Bearer (session token, not JWT)
        if (string.IsNullOrWhiteSpace(token) && context.Request.Headers.TryGetValue("Authorization", out var authHeader))
        {
            var auth = authHeader.FirstOrDefault();
            if (!string.IsNullOrWhiteSpace(auth) && auth.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
                token = auth["Bearer ".Length..].Trim();
        }

        if (string.IsNullOrWhiteSpace(token))
        {
            _logger.LogWarning("[Auth] No session token provided | Corr={Corr}", corrId);
            return false;
        }

        _logger.LogWarning("[Auth] Session validation starting | Corr={Corr}", corrId);

        try
        {
            var rqst = JsonSerializer.Serialize(new { sessionToken = token });
            var sessionProc = "dbo.spClientSession"; // Gateway handles CIAM client sessions only
            var resp = await sql.ExecProcAsync(sessionProc, "validate", rqst, ct: context.RequestAborted);

            if (string.IsNullOrWhiteSpace(resp))
            {
                _logger.LogWarning("[Auth] Session validation failed: empty response | Corr={Corr}", corrId);
                return false;
            }

            using var doc = JsonDocument.Parse(resp);
            var root = doc.RootElement;

            if (root.TryGetProperty("ok", out var okProp) && okProp.ValueKind == JsonValueKind.True)
            {
                var data = root.TryGetProperty("data", out var dataProp) ? dataProp : root;

                clientContext.SessionId = ReadString(data, "sessionId");
                clientContext.ClientId = ReadString(data, "clientId");
                clientContext.ClientName = ReadString(data, "clientName");
                clientContext.UserId = ReadString(data, "userId");
                clientContext.Email = ReadString(data, "userEmail");
                clientContext.Role = ReadString(data, "role");
                clientContext.IsDevBypass = false;

                // TenantId: session data first, then X-Tenant-Id header fallback
                // (In agency model, this is the client's Google Ads customer ID)
                clientContext.TenantId = ReadString(data, "tenantId") ?? ReadString(data, "googleCustomerId");

                // Fall back to X-Tenant-Id header if not in session data.
                // NOTE(review): this header is client-controlled; confirm downstream code checks
                // that the session's client is allowed to act on the requested tenant.
                if (string.IsNullOrWhiteSpace(clientContext.TenantId) &&
                    context.Request.Headers.TryGetValue("X-Tenant-Id", out var tenantHeader))
                {
                    clientContext.TenantId = tenantHeader.FirstOrDefault();
                }

                _logger.LogWarning("[Auth] Session validated OK | ClientId={ClientId} Email={Email} IsAdmin={IsAdmin} | Corr={Corr}",
                    clientContext.ClientId, clientContext.Email, clientContext.IsAdmin, corrId);

                return clientContext.IsAuthenticated;
            }

            _logger.LogWarning("[Auth] Session validation failed: ok=false | Corr={Corr}", corrId);
            return false;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "[Auth] Session validation error | Corr={Corr}", corrId);
            return false;
        }
    }

    /// <summary>
    /// Reads property <paramref name="name"/> of a JSON object as a string. Numbers are returned
    /// as their raw text (SQL ids such as userId/googleCustomerId are often numeric); missing,
    /// null and other value kinds yield null.
    /// </summary>
    private static string? ReadString(JsonElement obj, string name)
    {
        if (!obj.TryGetProperty(name, out var value))
            return null;

        return value.ValueKind switch
        {
            JsonValueKind.String => value.GetString(),
            JsonValueKind.Number => value.GetRawText(),
            _ => null
        };
    }

    /// <summary>
    /// Development bypass: when the environment is Development (or <c>Auth:AllowDevBypass</c>
    /// is true) and an <c>X-Dev-ClientId</c> header is present, trusts that header as the
    /// caller's client ID (and <c>X-Dev-TenantId</c> as the tenant).
    /// </summary>
    private bool TryDevBypass(HttpContext context, ClientContext clientContext, string corrId)
    {
        var env = _config["ASPNETCORE_ENVIRONMENT"] ?? Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT");
        var allowBypass = _config.GetValue<bool>("Auth:AllowDevBypass");

        if (!string.Equals(env, "Development", StringComparison.OrdinalIgnoreCase) && !allowBypass)
            return false;

        if (!context.Request.Headers.TryGetValue("X-Dev-ClientId", out var devClientId))
            return false;

        var clientId = devClientId.FirstOrDefault();
        if (string.IsNullOrWhiteSpace(clientId))
            return false;

        clientContext.ClientId = clientId;
        clientContext.IsDevBypass = true;

        if (context.Request.Headers.TryGetValue("X-Dev-TenantId", out var devTenantId))
            clientContext.TenantId = devTenantId.FirstOrDefault();

        _logger.LogWarning("[Auth] Dev bypass OK | ClientId={ClientId} | Corr={Corr}", clientId, corrId);
        return true;
    }

    /// <summary>True when a client ID is configured for <paramref name="provider"/>.</summary>
    private bool IsProviderConfigured(string provider)
    {
        return provider switch
        {
            "microsoft" => !string.IsNullOrWhiteSpace(
                _config["Auth:Microsoft:ClientId"] ?? _config["Auth:EntraId:ClientId"] ?? _config["ENTRA_CLIENT_ID"]),
            "google" => !string.IsNullOrWhiteSpace(
                _config["Auth:Google:ClientId"] ?? _config["GOOGLE_CLIENT_ID"]),
            _ => false
        };
    }

    /// <summary>
    /// True for exact anonymous paths and for paths equal to, or nested under, an anonymous
    /// prefix. Prefixes only match on a '/' segment boundary.
    /// </summary>
    private static bool IsAnonymousPath(string pathLower)
    {
        if (_anonymousExact.Contains(pathLower))
            return true;

        foreach (var prefix in _anonymousPrefixes)
        {
            if (pathLower.StartsWith(prefix, StringComparison.OrdinalIgnoreCase) &&
                (pathLower.Length == prefix.Length || pathLower[prefix.Length] == '/'))
                return true;
        }

        return false;
    }

    /// <summary>
    /// Returns the request's <c>X-Correlation-Id</c>, generating one (and writing it back onto
    /// the request headers) when absent or blank.
    /// </summary>
    private static string EnsureCorrelationId(HttpContext context)
    {
        const string header = "X-Correlation-Id";

        if (!context.Request.Headers.TryGetValue(header, out var existing) || string.IsNullOrWhiteSpace(existing.FirstOrDefault()))
        {
            var id = Guid.NewGuid().ToString("N");
            context.Request.Headers[header] = id;
            return id;
        }

        return existing.First()!;
    }

    /// <summary>
    /// Writes diagnostic response headers: correlation ID, the auth path taken, and the
    /// failure reason when one is supplied.
    /// </summary>
    private static void SetAuthHeaders(HttpContext context, string corrId, string authPath, string? authFail)
    {
        context.Response.Headers["X-Correlation-Id"] = corrId;
        context.Response.Headers["X-Auth-Path"] = authPath;
        if (!string.IsNullOrWhiteSpace(authFail))
            context.Response.Headers["X-Auth-Fail"] = authFail;
    }

    /// <summary>
    /// Returns the cached OIDC configuration manager for this provider + metadata address,
    /// creating it on first use.
    /// </summary>
    /// <param name="provider">Logical provider name (part of the cache key).</param>
    /// <param name="metadataAddress">OpenID Connect discovery document URL.</param>
    /// <param name="derivedFromToken">
    /// True when the address was built from an unvalidated token; such entries are capped at
    /// <see cref="MaxTokenDerivedOidcManagers"/>.
    /// </param>
    /// <returns>The manager, or null when a new token-derived entry would exceed the cap.</returns>
    private static ConfigurationManager<OpenIdConnectConfiguration>? GetOrCreateOidcManager(
        string provider, string metadataAddress, bool derivedFromToken = false)
    {
        // Key by address too: staff vs client tenants (and distinct CIAM domains) publish
        // different metadata and must not share one cached configuration.
        var key = $"{provider}|{metadataAddress}";

        lock (_oidcLock)
        {
            if (_oidcManagers.TryGetValue(key, out var existing))
                return existing;

            if (derivedFromToken)
            {
                if (_tokenDerivedOidcManagers >= MaxTokenDerivedOidcManagers)
                    return null;
                _tokenDerivedOidcManagers++;
            }

            var created = new ConfigurationManager<OpenIdConnectConfiguration>(
                metadataAddress,
                new OpenIdConnectConfigurationRetriever());
            _oidcManagers[key] = created;
            return created;
        }
    }
}