Files
AdPlatform-Server/Gateway/Security/AccessLogMiddleware.cs
2026-02-03 15:04:37 -08:00

170 lines
5.3 KiB
C#

using Gateway.Data;
using System.Diagnostics;
using System.Text.Json;
namespace Gateway.Security;
/// <summary>
/// Records HTTP requests for security monitoring and debugging by sending a JSON payload
/// to the <c>dbo.spAccessLog</c> stored procedure (expected to write to tbAccessLog).
/// Should be registered early in the pipeline (after routing, before auth).
/// The database write is started but not awaited, so it does not add to response time.
/// </summary>
/// <remarks>
/// NOTE(review): <c>SqlService</c> and <c>ClientContext</c> are supplied per invocation and
/// the un-awaited write may still be in flight after the request has finished — confirm
/// that <c>SqlService</c> tolerates being used past the end of the request.
/// </remarks>
public sealed class AccessLogMiddleware
{
    /// <summary>Error code recorded when an exception escapes the rest of the pipeline.</summary>
    private const string UnhandledExceptionCode = "unhandled-exception";

    /// <summary>Status recorded for an escaped exception when no response has been started yet.</summary>
    private const int InternalServerErrorStatus = 500;

    /// <summary>Maximum number of query-string characters written to the log.</summary>
    private const int MaxQueryStringLength = 1000;

    // Paths to skip logging (health checks, static files, etc.)
    private static readonly HashSet<string> _skipPaths = new(StringComparer.OrdinalIgnoreCase)
    {
        "/health",
        "/favicon.ico"
    };

    // Names whose presence anywhere in the query string causes the whole string to be redacted
    // (add more as needed). Hoisted to a static so it is not re-allocated on every request.
    private static readonly string[] _sensitiveParams = { "token", "key", "secret", "password", "apikey" };

    private readonly RequestDelegate _next;
    private readonly ILogger<AccessLogMiddleware> _logger;

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">The next delegate in the request pipeline.</param>
    /// <param name="logger">Logger used to report failures of the access-log write itself.</param>
    public AccessLogMiddleware(RequestDelegate next, ILogger<AccessLogMiddleware> logger)
    {
        _next = next;
        _logger = logger;
    }

    /// <summary>
    /// Runs the rest of the pipeline, times it, and starts an access-log write for the request.
    /// Requests whose path is on the skip list are passed through without logging.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="clientContext">Supplies the user, client and session ids written to the log.</param>
    /// <param name="sql">Service used to execute the access-log stored procedure.</param>
    /// <returns>A task that completes when the rest of the pipeline has completed.</returns>
    /// <remarks>Exceptions from the rest of the pipeline are recorded and then re-thrown unchanged.</remarks>
    public async Task InvokeAsync(HttpContext context, ClientContext clientContext, SqlService sql)
    {
        var path = context.Request.Path.Value ?? "/";

        // Skip logging for noisy endpoints
        if (ShouldSkip(path))
        {
            await _next(context);
            return;
        }

        var stopwatch = Stopwatch.StartNew();
        string? errorCode = null;
        string? errorMessage = null;
        var pipelineThrew = false;

        try
        {
            await _next(context);
        }
        catch (Exception ex)
        {
            pipelineThrew = true;
            errorCode = UnhandledExceptionCode;
            errorMessage = ex.Message;
            throw; // Re-throw to let error handling middleware deal with it
        }
        finally
        {
            stopwatch.Stop();

            // Capture error info from response headers if set by auth middleware.
            // Only a non-empty value overrides, so an empty header cannot erase the
            // "unhandled-exception" code recorded above.
            if (context.Response.Headers.TryGetValue("X-Auth-Fail", out var authFail))
            {
                var authFailCode = authFail.FirstOrDefault();
                if (!string.IsNullOrEmpty(authFailCode))
                {
                    errorCode = authFailCode;
                }
            }

            // When an exception is propagating, whatever handles it further up the pipeline has
            // not run yet, so Response.StatusCode still holds its pre-exception value (normally
            // 200). Record 500 in that case unless a response has already started.
            var statusCode = pipelineThrew && !context.Response.HasStarted
                ? InternalServerErrorStatus
                : context.Response.StatusCode;

            // Fire-and-forget logging (don't await)
            _ = LogAccessAsync(
                sql,
                context,
                clientContext,
                stopwatch.ElapsedMilliseconds,
                statusCode,
                errorCode,
                errorMessage);
        }
    }

    /// <summary>
    /// Returns true for paths that are not logged: exact matches in the skip list
    /// (case-insensitive) and anything starting with <c>/swagger</c>.
    /// </summary>
    private static bool ShouldSkip(string path)
    {
        if (_skipPaths.Contains(path))
        {
            return true;
        }

        // Skip swagger
        return path.StartsWith("/swagger", StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Builds the JSON log payload and passes it to <c>dbo.spAccessLog</c>.
    /// Everything read from <paramref name="context"/> and <paramref name="clientContext"/>
    /// is read before the first await, i.e. synchronously inside the caller.
    /// Any exception is logged and swallowed so that logging never affects the response.
    /// </summary>
    /// <param name="sql">Service used to execute the stored procedure.</param>
    /// <param name="context">The HTTP context the payload is built from.</param>
    /// <param name="clientContext">Supplies the user, client and session ids.</param>
    /// <param name="durationMs">Elapsed time of the downstream pipeline in milliseconds.</param>
    /// <param name="statusCode">Status code to record for the request.</param>
    /// <param name="errorCode">Error code to record, or null.</param>
    /// <param name="errorMessage">Error message to record, or null.</param>
    private async Task LogAccessAsync(
        SqlService sql,
        HttpContext context,
        ClientContext clientContext,
        long durationMs,
        int statusCode,
        string? errorCode,
        string? errorMessage)
    {
        try
        {
            // Prefer the correlation id sent by the caller; fall back to one set on the response.
            var correlationId = context.Request.Headers["X-Correlation-Id"].FirstOrDefault()
                ?? context.Response.Headers["X-Correlation-Id"].FirstOrDefault();
            var authPath = context.Response.Headers["X-Auth-Path"].FirstOrDefault();

            var rqst = JsonSerializer.Serialize(new
            {
                correlationId,
                method = context.Request.Method,
                path = context.Request.Path.Value,
                queryString = context.Request.QueryString.HasValue
                    ? SanitizeQueryString(context.Request.QueryString.Value)
                    : null,
                authPath,
                userId = clientContext.UserId,
                clientId = clientContext.ClientId,
                sessionId = clientContext.SessionId,
                statusCode,
                durationMs,
                ipAddress = GetClientIp(context),
                userAgent = context.Request.Headers.UserAgent.FirstOrDefault(),
                errorCode,
                errorMessage
            });

            await sql.ExecProcAsync("dbo.spAccessLog", "log", rqst);
        }
        catch (Exception ex)
        {
            // Don't let logging failures affect the response
            _logger.LogError(ex, "Failed to write access log");
        }
    }

    /// <summary>
    /// Returns the first entry of <c>X-Forwarded-For</c> when that header has one,
    /// otherwise the remote address of the connection (which may be null).
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>X-Forwarded-For</c> is a request header and is used here as sent;
    /// the logged address is only trustworthy if a proxy in front of this service sets or
    /// overwrites it — confirm the deployment.
    /// </remarks>
    private static string? GetClientIp(HttpContext context)
    {
        // Check X-Forwarded-For first (for requests behind load balancer/proxy)
        var forwarded = context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
        if (!string.IsNullOrWhiteSpace(forwarded))
        {
            var firstHop = forwarded.Split(',')[0].Trim();

            // A malformed header such as " ,1.2.3.4" has an empty first element; fall
            // through to the connection address instead of logging an empty string.
            if (firstHop.Length > 0)
            {
                return firstHop;
            }
        }

        return context.Connection.RemoteIpAddress?.ToString();
    }

    /// <summary>
    /// Prepares a query string for logging.
    /// Returns null for a null/blank input, the literal <c>[REDACTED]</c> when any sensitive
    /// name occurs anywhere in the string (case-insensitive substring match, so the whole
    /// query string is replaced, not just the matching parameter), and otherwise the input
    /// truncated to <see cref="MaxQueryStringLength"/> characters.
    /// </summary>
    private static string? SanitizeQueryString(string? queryString)
    {
        if (string.IsNullOrWhiteSpace(queryString))
        {
            return null;
        }

        // Regex-free and deliberately coarse: a substring hit (e.g. "key" inside "monkey")
        // redacts everything rather than risk logging a secret.
        foreach (var param in _sensitiveParams)
        {
            if (queryString.Contains(param, StringComparison.OrdinalIgnoreCase))
            {
                return "[REDACTED]";
            }
        }

        // Truncate if too long
        return queryString.Length > MaxQueryStringLength
            ? queryString[..MaxQueryStringLength]
            : queryString;
    }
}
/// <summary>
/// Pipeline-registration helper for <see cref="AccessLogMiddleware"/>, intended to keep
/// Program.cs tidy.
/// </summary>
public static class AccessLogMiddlewareExtensions
{
    /// <summary>
    /// Adds <see cref="AccessLogMiddleware"/> to the request pipeline.
    /// </summary>
    /// <param name="builder">The application builder to register the middleware on.</param>
    /// <returns>The result of <c>UseMiddleware</c>, so further registrations can be chained.</returns>
    public static IApplicationBuilder UseAccessLogging(this IApplicationBuilder builder) =>
        builder.UseMiddleware<AccessLogMiddleware>();
}