ApiAuthenticationMiddleware performance improvements

Previously we've used one semaphore per all ongoing authentication attempts, which is suboptimal given the existence of a lot of consumers, including ongoing (D)DoS or distributed bruteforce attack. ASF should be as resistant to that as possible, therefore it makes sense to replace the global semaphore with per-IP semaphore (actually task), that can control the access just as well, without stopping other consumers from accessing the same authentication process concurrently.
This commit is contained in:
Archi
2021-08-24 01:37:14 +02:00
parent 47855ca705
commit 69e2a3590c
2 changed files with 31 additions and 14 deletions

View File

@@ -19,6 +19,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if NETFRAMEWORK
using JustArchiNET.Madness;
#endif
using System;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
@@ -45,7 +48,7 @@ namespace ArchiSteamFarm.IPC.Integration {
private const byte FailedAuthorizationsCooldownInHours = 1;
private const byte MaxFailedAuthorizationAttempts = 5;
private static readonly SemaphoreSlim AuthorizationSemaphore = new(1, 1);
private static readonly ConcurrentDictionary<IPAddress, Task> AuthorizationTasks = new();
private static readonly Timer ClearFailedAuthorizationsTimer = new(ClearFailedAuthorizations);
private static readonly ConcurrentDictionary<IPAddress, byte> FailedAuthorizations = new();
@@ -150,23 +153,37 @@ namespace ArchiSteamFarm.IPC.Integration {
bool authorized = ipcPassword == inputHash;
await AuthorizationSemaphore.WaitAsync().ConfigureAwait(false);
while (true) {
if (AuthorizationTasks.TryGetValue(clientIP, out Task? task)) {
await task.ConfigureAwait(false);
try {
bool hasFailedAuthorizations = FailedAuthorizations.TryGetValue(clientIP, out attempts);
if (hasFailedAuthorizations && (attempts >= MaxFailedAuthorizationAttempts)) {
return (HttpStatusCode.Forbidden, false);
continue;
}
if (!authorized) {
FailedAuthorizations[clientIP] = hasFailedAuthorizations ? ++attempts : (byte) 1;
TaskCompletionSource taskCompletionSource = new();
if (!AuthorizationTasks.TryAdd(clientIP, taskCompletionSource.Task)) {
continue;
}
} finally {
AuthorizationSemaphore.Release();
try {
bool hasFailedAuthorizations = FailedAuthorizations.TryGetValue(clientIP, out attempts);
if (hasFailedAuthorizations && (attempts >= MaxFailedAuthorizationAttempts)) {
return (HttpStatusCode.Forbidden, false);
}
if (!authorized) {
FailedAuthorizations[clientIP] = hasFailedAuthorizations ? ++attempts : (byte) 1;
}
} finally {
AuthorizationTasks.TryRemove(clientIP, out _);
taskCompletionSource.SetResult();
}
return (authorized ? HttpStatusCode.OK : HttpStatusCode.Unauthorized, true);
}
return (authorized ? HttpStatusCode.OK : HttpStatusCode.Unauthorized, true);
}
}
}