170 lines
5.3 KiB
C#
170 lines
5.3 KiB
C#
using Gateway.Data;
|
|
using System.Diagnostics;
|
|
using System.Text.Json;
|
|
|
|
namespace Gateway.Security;
|
|
|
|
/// <summary>
/// Logs all HTTP requests to tbAccessLog for security monitoring and debugging.
/// Should be registered early in the pipeline (after routing, before auth).
/// </summary>
/// <remarks>
/// Once the downstream pipeline has finished, the request/response details are captured
/// synchronously into a JSON payload. Only the database write (<c>dbo.spAccessLog</c>) is
/// started in the background and not awaited, so it does not delay completion of the response.
/// </remarks>
public sealed class AccessLogMiddleware
{
    // Query strings longer than this are truncated before being logged.
    private const int MaxQueryStringLength = 1000;

    // Replaces the value of any query parameter judged sensitive.
    private const string RedactedValue = "[REDACTED]";

    // Paths to skip logging (health checks, static files, etc.)
    private static readonly HashSet<string> _skipPaths = new(StringComparer.OrdinalIgnoreCase)
    {
        "/health",
        "/favicon.ico"
    };

    // Case-insensitive terms that mark a query parameter as sensitive when they appear in its
    // (percent-decoded) name or value. Add more as needed.
    private static readonly string[] _sensitiveQueryTerms = { "token", "key", "secret", "password", "apikey" };

    private readonly RequestDelegate _next;
    private readonly ILogger<AccessLogMiddleware> _logger;

    public AccessLogMiddleware(RequestDelegate next, ILogger<AccessLogMiddleware> logger)
    {
        _next = next;
        _logger = logger;
    }

    /// <summary>
    /// Runs the rest of the pipeline, then records one access-log entry for the request,
    /// unless its path is excluded by <see cref="ShouldSkip"/>.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="clientContext">Caller identity (user, client and session ids) recorded in the log.</param>
    /// <param name="sql">Service used to execute <c>dbo.spAccessLog</c>.</param>
    /// <remarks>An exception thrown downstream is recorded and then re-thrown unchanged.</remarks>
    public async Task InvokeAsync(HttpContext context, ClientContext clientContext, SqlService sql)
    {
        var path = context.Request.Path.Value ?? "/";

        // Skip logging for noisy endpoints
        if (ShouldSkip(path))
        {
            await _next(context);
            return;
        }

        var stopwatch = Stopwatch.StartNew();
        string? errorCode = null;
        string? errorMessage = null;
        var unhandledException = false;

        try
        {
            await _next(context);
        }
        catch (Exception ex)
        {
            unhandledException = true;
            errorCode = "unhandled-exception";
            errorMessage = ex.Message;
            throw; // Re-throw to let error handling middleware deal with it
        }
        finally
        {
            stopwatch.Stop();

            // Capture error info from response headers if set by auth middleware.
            // An empty header value must not erase an error code recorded above.
            if (context.Response.Headers.TryGetValue("X-Auth-Fail", out var authFail))
            {
                var authFailCode = authFail.FirstOrDefault();
                if (!string.IsNullOrEmpty(authFailCode))
                {
                    errorCode = authFailCode;
                }
            }

            // Fire-and-forget logging (don't await)
            _ = LogAccessAsync(sql, context, clientContext, stopwatch.ElapsedMilliseconds,
                errorCode, errorMessage, unhandledException);
        }
    }

    /// <summary>Returns true for paths that should not be access-logged (exact skip list or /swagger prefix).</summary>
    private static bool ShouldSkip(string path)
    {
        if (_skipPaths.Contains(path))
            return true;

        // Skip swagger
        if (path.StartsWith("/swagger", StringComparison.OrdinalIgnoreCase))
            return true;

        return false;
    }

    /// <summary>
    /// Builds the access-log payload from the request/response and starts writing it.
    /// </summary>
    /// <remarks>
    /// Every read of <paramref name="context"/> happens synchronously, before this method returns.
    /// This matters because the caller does not await the returned task, and the
    /// <see cref="HttpContext"/> must not be used after the request completes.
    /// Failures are logged rather than propagated.
    /// </remarks>
    private Task LogAccessAsync(
        SqlService sql,
        HttpContext context,
        ClientContext clientContext,
        long durationMs,
        string? errorCode,
        string? errorMessage,
        bool unhandledException)
    {
        string payload;
        try
        {
            payload = BuildPayload(context, clientContext, durationMs, errorCode, errorMessage, unhandledException);
        }
        catch (Exception ex)
        {
            // Don't let logging failures affect the response
            _logger.LogError(ex, "Failed to write access log");
            return Task.CompletedTask;
        }

        return WriteAccessLogAsync(sql, payload);
    }

    /// <summary>Serializes the request/response snapshot that is passed to <c>dbo.spAccessLog</c>.</summary>
    private static string BuildPayload(
        HttpContext context,
        ClientContext clientContext,
        long durationMs,
        string? errorCode,
        string? errorMessage,
        bool unhandledException)
    {
        // Prefer the caller-supplied correlation id; fall back to the one echoed on the response.
        var correlationId = context.Request.Headers["X-Correlation-Id"].FirstOrDefault()
            ?? context.Response.Headers["X-Correlation-Id"].FirstOrDefault();

        var authPath = context.Response.Headers["X-Auth-Path"].FirstOrDefault();

        // If the pipeline threw before anything was sent, StatusCode still holds its default (200).
        // The server or exception handler will normally answer 500, so record that instead.
        var statusCode = unhandledException && !context.Response.HasStarted
            ? StatusCodes.Status500InternalServerError
            : context.Response.StatusCode;

        return JsonSerializer.Serialize(new
        {
            correlationId,
            method = context.Request.Method,
            path = context.Request.Path.Value,
            queryString = context.Request.QueryString.HasValue
                ? SanitizeQueryString(context.Request.QueryString.Value)
                : null,
            authPath,
            userId = clientContext.UserId,
            clientId = clientContext.ClientId,
            sessionId = clientContext.SessionId,
            statusCode,
            durationMs,
            ipAddress = GetClientIp(context),
            userAgent = context.Request.Headers.UserAgent.FirstOrDefault(),
            errorCode,
            errorMessage
        });
    }

    /// <summary>Executes <c>dbo.spAccessLog</c> with the prepared payload, logging (not throwing) on failure.</summary>
    private async Task WriteAccessLogAsync(SqlService sql, string payload)
    {
        try
        {
            // NOTE(review): this may still be running after the request has completed. If SqlService
            // is scoped and disposed with the request scope, this call can race with disposal.
            // Consider resolving it from a fresh IServiceScopeFactory scope, and confirm SqlService's lifetime.
            await sql.ExecProcAsync("dbo.spAccessLog", "log", payload);
        }
        catch (Exception ex)
        {
            // Don't let logging failures affect the response
            _logger.LogError(ex, "Failed to write access log");
        }
    }

    /// <summary>
    /// Returns the first entry of X-Forwarded-For when present, otherwise the connection's remote address.
    /// </summary>
    /// <remarks>
    /// NOTE(review): X-Forwarded-For is read verbatim from the request. Any client can therefore spoof
    /// the logged IP unless a trusted proxy overwrites the header. Consider configuring
    /// ForwardedHeadersMiddleware (KnownProxies/KnownNetworks) and logging only
    /// <c>Connection.RemoteIpAddress</c>.
    /// </remarks>
    private static string? GetClientIp(HttpContext context)
    {
        // Check X-Forwarded-For first (for requests behind load balancer/proxy)
        var forwarded = context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
        if (!string.IsNullOrWhiteSpace(forwarded))
        {
            return forwarded.Split(',')[0].Trim();
        }

        return context.Connection.RemoteIpAddress?.ToString();
    }

    /// <summary>
    /// Prepares a raw query string (including any leading '?') for logging.
    /// </summary>
    /// <remarks>
    /// Any parameter whose decoded name or value contains a sensitive term has its value replaced
    /// with "[REDACTED]"; all other parameters are kept as-is. The result is truncated to
    /// <see cref="MaxQueryStringLength"/> characters.
    /// </remarks>
    /// <returns><c>null</c> for a null or blank input; otherwise the sanitized query string.</returns>
    private static string? SanitizeQueryString(string? queryString)
    {
        if (string.IsNullOrWhiteSpace(queryString))
            return null;

        var hasPrefix = queryString.StartsWith('?');
        var pairs = (hasPrefix ? queryString[1..] : queryString).Split('&');

        for (var i = 0; i < pairs.Length; i++)
        {
            var pair = pairs[i];
            var separator = pair.IndexOf('=');
            if (separator < 0)
            {
                // A bare segment without '=' is redacted wholesale if it looks sensitive.
                if (ContainsSensitiveTerm(pair))
                    pairs[i] = RedactedValue;
                continue;
            }

            var name = pair[..separator];
            var value = pair[(separator + 1)..];
            if (ContainsSensitiveTerm(name) || ContainsSensitiveTerm(value))
                pairs[i] = name + "=" + RedactedValue;
        }

        var sanitized = (hasPrefix ? "?" : string.Empty) + string.Join('&', pairs);

        // Truncate if too long
        return sanitized.Length > MaxQueryStringLength ? sanitized[..MaxQueryStringLength] : sanitized;
    }

    // True when the percent-decoded text contains any sensitive term (case-insensitive).
    // Decoding first catches encoded names and values, such as an escaped nested URL carrying "token=...".
    private static bool ContainsSensitiveTerm(string text)
    {
        if (text.Length == 0)
            return false;

        var decoded = Uri.UnescapeDataString(text.Replace('+', ' '));
        return _sensitiveQueryTerms.Any(term => decoded.Contains(term, StringComparison.OrdinalIgnoreCase));
    }
}
|
|
|
|
/// <summary>
/// Registration helper that adds <see cref="AccessLogMiddleware"/> to the application's request pipeline.
/// </summary>
public static class AccessLogMiddlewareExtensions
{
    /// <summary>
    /// Appends <see cref="AccessLogMiddleware"/> to the request pipeline.
    /// </summary>
    /// <param name="builder">The application builder being configured.</param>
    /// <returns>The builder returned by <c>UseMiddleware</c>, allowing further calls to be chained.</returns>
    public static IApplicationBuilder UseAccessLogging(this IApplicationBuilder builder) =>
        builder.UseMiddleware<AccessLogMiddleware>();
}
|