-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #290 from rGunti/feature/rate-limit
Implemented basic rate limiting
- Loading branch information
Showing
8 changed files
with
160 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
using System.Net; | ||
using System.Net.Sockets; | ||
|
||
namespace FloppyBot.WebApi.Agent.Utils; | ||
|
||
public static class HttpContextHelpers | ||
{ | ||
private const string HEADER_FORWARDED_FOR = "X-Forwarded-For"; | ||
private const string HEADER_REAL_IP = "X-Real-IP"; | ||
|
||
public static IPAddress? GetRemoteHostIpFromHeaders(this HttpContext httpContext) | ||
{ | ||
return httpContext | ||
.Request.Headers.GetCommaSeparatedValues(HEADER_REAL_IP) | ||
.Concat(httpContext.Request.Headers.GetCommaSeparatedValues(HEADER_FORWARDED_FOR)) | ||
.Select(ip => IPAddress.TryParse(ip, out var address) ? address : null) | ||
.FirstOrDefault( | ||
ip => | ||
ip?.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6, | ||
httpContext.Connection.RemoteIpAddress | ||
); | ||
} | ||
|
||
public static ILogger GetLogger(this HttpContext httpContext, string categoryName) | ||
{ | ||
return httpContext | ||
.RequestServices.GetRequiredService<ILoggerFactory>() | ||
.CreateLogger(categoryName); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
using System.Globalization; | ||
using System.Security.Cryptography; | ||
using System.Text; | ||
using System.Threading.RateLimiting; | ||
using Microsoft.AspNetCore.RateLimiting; | ||
using Microsoft.Extensions.Options; | ||
|
||
namespace FloppyBot.WebApi.Agent.Utils; | ||
|
||
internal static class Limiters | ||
{ | ||
private const string KEY_GLOBAL = "global"; | ||
private const string KEY_AUTH = "auth"; | ||
|
||
private const string LOGGER_CATEGORY = "RateLimiter"; | ||
|
||
private const string SECTION_DEFAULT = "RateLimiter:Default"; | ||
private const string SECTION_AUTH = "RateLimiter:Authenticated"; | ||
|
||
internal static IServiceCollection ConfigureRateLimiter( | ||
this IServiceCollection services, | ||
IConfiguration config | ||
) | ||
{ | ||
return services | ||
.Configure<TokenBucketRateLimiterOptions>( | ||
KEY_GLOBAL, | ||
o => config.GetRequiredSection(SECTION_DEFAULT).Bind(o) | ||
) | ||
.Configure<TokenBucketRateLimiterOptions>( | ||
KEY_AUTH, | ||
o => config.GetRequiredSection(SECTION_AUTH).Bind(o) | ||
) | ||
.AddRateLimiter(rl => | ||
{ | ||
rl.OnRejected = RateLimiter_OnRejected; | ||
rl.RejectionStatusCode = StatusCodes.Status429TooManyRequests; | ||
rl.GlobalLimiter = PartitionedRateLimiter.Create<HttpContext, string>( | ||
RateLimiter_Build | ||
); | ||
}); | ||
} | ||
|
||
private static RateLimitPartition<string> RateLimiter_Build(HttpContext httpContext) | ||
{ | ||
var logger = httpContext.GetLogger(LOGGER_CATEGORY); | ||
|
||
string? accessToken = httpContext.Request.Headers.Authorization.ToString().HashString(); | ||
string? remoteIp = httpContext.GetRemoteHostIpFromHeaders()?.ToString(); | ||
|
||
var partitionKey = accessToken ?? remoteIp ?? KEY_GLOBAL; | ||
logger.LogTrace("Building rate limiter for partition={PartitionKey}", partitionKey); | ||
return RateLimitPartition.GetTokenBucketLimiter( | ||
accessToken ?? remoteIp ?? KEY_GLOBAL, | ||
_ => | ||
httpContext | ||
.RequestServices.GetRequiredService< | ||
IOptionsFactory<TokenBucketRateLimiterOptions> | ||
>() | ||
.Create(!string.IsNullOrWhiteSpace(accessToken) ? KEY_AUTH : KEY_GLOBAL) | ||
); | ||
} | ||
|
||
private static async ValueTask RateLimiter_OnRejected( | ||
OnRejectedContext context, | ||
CancellationToken cancellationToken | ||
) | ||
{ | ||
var response = context.HttpContext.Response; | ||
if (context.Lease.TryGetMetadata(MetadataName.RetryAfter, out var retryAfter)) | ||
{ | ||
// Add a Retry-After header to the response | ||
response.Headers.RetryAfter = ((int)retryAfter.TotalSeconds).ToString( | ||
NumberFormatInfo.InvariantInfo | ||
); | ||
} | ||
|
||
response.StatusCode = StatusCodes.Status429TooManyRequests; | ||
|
||
context | ||
.HttpContext.GetLogger(LOGGER_CATEGORY) | ||
.LogWarning( | ||
"Request rejected from addr={UserRequestAddress} to path={UserRequestPath}", | ||
context.HttpContext.GetRemoteHostIpFromHeaders(), | ||
context.HttpContext.Request.Path | ||
); | ||
await response.WriteAsync( | ||
"Whoo there, calm down, mate! You're exceeding the speed limit here, gonna call the cops next time.", | ||
cancellationToken | ||
); | ||
} | ||
|
||
private static string? HashString(this string? s) | ||
{ | ||
return string.IsNullOrWhiteSpace(s) | ||
? null | ||
: Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(s))); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters