@@ -1,26 +1,109 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Internal | |||
{ | |||
public sealed class AsyncAutoResetEvent : IDisposable | |||
// Inspired from Stephen Toub (https://blogs.msdn.microsoft.com/pfxteam/2012/02/11/building-async-coordination-primitives-part-2-asyncautoresetevent/) and Chris Gillum (https://stackoverflow.com/a/43012490) | |||
public class AsyncAutoResetEvent | |||
{ | |||
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(0, 1); | |||
private readonly LinkedList<TaskCompletionSource<bool>> waiters = new LinkedList<TaskCompletionSource<bool>>(); | |||
public Task WaitOneAsync(CancellationToken cancellationToken) | |||
private bool isSignaled; | |||
public AsyncAutoResetEvent(bool signaled = false) | |||
{ | |||
return _semaphore.WaitAsync(cancellationToken); | |||
this.isSignaled = signaled; | |||
} | |||
public void Set() | |||
public Task<bool> WaitOneAsync() | |||
{ | |||
return this.WaitOneAsync(CancellationToken.None); | |||
} | |||
public Task<bool> WaitOneAsync(TimeSpan timeout) | |||
{ | |||
_semaphore.Release(); | |||
return this.WaitOneAsync(timeout, CancellationToken.None); | |||
} | |||
public void Dispose() | |||
public Task<bool> WaitOneAsync(CancellationToken cancellationToken) | |||
{ | |||
_semaphore?.Dispose(); | |||
return this.WaitOneAsync(Timeout.InfiniteTimeSpan, cancellationToken); | |||
} | |||
public async Task<bool> WaitOneAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
cancellationToken.ThrowIfCancellationRequested(); | |||
TaskCompletionSource<bool> tcs; | |||
lock (this.waiters) | |||
{ | |||
if (this.isSignaled) | |||
{ | |||
this.isSignaled = false; | |||
return true; | |||
} | |||
else if (timeout == TimeSpan.Zero) | |||
{ | |||
return this.isSignaled; | |||
} | |||
else | |||
{ | |||
tcs = new TaskCompletionSource<bool>(); | |||
this.waiters.AddLast(tcs); | |||
} | |||
} | |||
Task winner = await Task.WhenAny(tcs.Task, Task.Delay(timeout, cancellationToken)); | |||
if (winner == tcs.Task) | |||
{ | |||
// The task was signaled. | |||
return true; | |||
} | |||
else | |||
{ | |||
// We timed-out; remove our reference to the task. | |||
// This is an O(n) operation since waiters is a LinkedList<T>. | |||
lock (this.waiters) | |||
{ | |||
bool removed = this.waiters.Remove(tcs); | |||
if (winner.Status == TaskStatus.Canceled) | |||
{ | |||
throw new OperationCanceledException(cancellationToken); | |||
} | |||
else | |||
{ | |||
throw new TimeoutException(); | |||
} | |||
} | |||
} | |||
} | |||
public void Set() | |||
{ | |||
TaskCompletionSource<bool> toRelease = null; | |||
lock (this.waiters) | |||
{ | |||
if (this.waiters.Count > 0) | |||
{ | |||
// Signal the first task in the waiters list. | |||
toRelease = this.waiters.First.Value; | |||
this.waiters.RemoveFirst(); | |||
} | |||
else if (!this.isSignaled) | |||
{ | |||
// No tasks are pending | |||
this.isSignaled = true; | |||
} | |||
} | |||
if (toRelease != null) | |||
{ | |||
toRelease.SetResult(true); | |||
} | |||
} | |||
} | |||
} |
@@ -4,23 +4,32 @@ using System.Threading.Tasks; | |||
namespace MQTTnet.Internal | |||
{ | |||
public sealed class AsyncLock : IDisposable | |||
// From Stephen Toub (https://blogs.msdn.microsoft.com/pfxteam/2012/02/12/building-async-coordination-primitives-part-6-asynclock/) | |||
public sealed class AsyncLock | |||
{ | |||
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); | |||
private readonly SemaphoreSlim m_semaphore = new SemaphoreSlim(1, 1); | |||
private readonly Task<IDisposable> m_releaser; | |||
public Task EnterAsync(CancellationToken cancellationToken) | |||
public AsyncLock() | |||
{ | |||
return _semaphore.WaitAsync(cancellationToken); | |||
m_releaser = Task.FromResult((IDisposable)new Releaser(this)); | |||
} | |||
public void Exit() | |||
public Task<IDisposable> LockAsync(CancellationToken cancellationToken) | |||
{ | |||
_semaphore.Release(); | |||
var wait = m_semaphore.WaitAsync(cancellationToken); | |||
return wait.IsCompleted ? | |||
m_releaser : | |||
wait.ContinueWith((_, state) => (IDisposable)state, | |||
m_releaser.Result, cancellationToken, | |||
TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); | |||
} | |||
public void Dispose() | |||
private sealed class Releaser : IDisposable | |||
{ | |||
_semaphore?.Dispose(); | |||
private readonly AsyncLock m_toRelease; | |||
internal Releaser(AsyncLock toRelease) { m_toRelease = toRelease; } | |||
public void Dispose() { m_toRelease.m_semaphore.Release(); } | |||
} | |||
} | |||
} |
@@ -154,7 +154,6 @@ namespace MQTTnet.Server | |||
public void Dispose() | |||
{ | |||
_queueAutoResetEvent?.Dispose(); | |||
} | |||
} | |||
} |
@@ -119,8 +119,7 @@ namespace MQTTnet.Server | |||
public async Task StopAsync() | |||
{ | |||
await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
foreach (var session in _sessions) | |||
{ | |||
@@ -129,16 +128,11 @@ namespace MQTTnet.Server | |||
_sessions.Clear(); | |||
} | |||
finally | |||
{ | |||
_sessionsLock.Exit(); | |||
} | |||
} | |||
public async Task<IList<ConnectedMqttClient>> GetConnectedClientsAsync() | |||
{ | |||
await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
return _sessions.Where(s => s.Value.IsConnected).Select(s => new ConnectedMqttClient | |||
{ | |||
@@ -149,10 +143,6 @@ namespace MQTTnet.Server | |||
PendingApplicationMessages = s.Value.PendingMessagesQueue.Count | |||
}).ToList(); | |||
} | |||
finally | |||
{ | |||
_sessionsLock.Exit(); | |||
} | |||
} | |||
public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
@@ -165,8 +155,7 @@ namespace MQTTnet.Server | |||
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
if (!_sessions.TryGetValue(clientId, out var session)) | |||
{ | |||
@@ -175,10 +164,6 @@ namespace MQTTnet.Server | |||
await session.SubscribeAsync(topicFilters).ConfigureAwait(false); | |||
} | |||
finally | |||
{ | |||
_sessionsLock.Exit(); | |||
} | |||
} | |||
public async Task UnsubscribeAsync(string clientId, IList<string> topicFilters) | |||
@@ -186,8 +171,7 @@ namespace MQTTnet.Server | |||
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
if (!_sessions.TryGetValue(clientId, out var session)) | |||
{ | |||
@@ -196,15 +180,10 @@ namespace MQTTnet.Server | |||
await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false); | |||
} | |||
finally | |||
{ | |||
_sessionsLock.Exit(); | |||
} | |||
} | |||
public void Dispose() | |||
{ | |||
_sessionsLock?.Dispose(); | |||
} | |||
private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) | |||
@@ -226,8 +205,8 @@ namespace MQTTnet.Server | |||
private async Task<GetOrCreateClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket) | |||
{ | |||
await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); | |||
if (isSessionPresent) | |||
@@ -261,10 +240,6 @@ namespace MQTTnet.Server | |||
return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; | |||
} | |||
finally | |||
{ | |||
_sessionsLock.Exit(); | |||
} | |||
} | |||
private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
@@ -294,18 +269,13 @@ namespace MQTTnet.Server | |||
_logger.Error(exception, "Error while processing application message"); | |||
} | |||
await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
foreach (var clientSession in _sessions.Values) | |||
{ | |||
await clientSession.EnqueueApplicationMessageAsync(applicationMessage).ConfigureAwait(false); | |||
} | |||
} | |||
finally | |||
{ | |||
_sessionsLock.Exit(); | |||
} | |||
} | |||
private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
@@ -29,43 +29,39 @@ namespace MQTTnet.Server | |||
return; | |||
} | |||
await _messagesLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); | |||
try | |||
{ | |||
var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); | |||
_messages.Clear(); | |||
foreach (var retainedMessage in retainedMessages) | |||
_messages.Clear(); | |||
foreach (var retainedMessage in retainedMessages) | |||
{ | |||
_messages[retainedMessage.Topic] = retainedMessage; | |||
} | |||
} | |||
catch (Exception exception) | |||
{ | |||
_messages[retainedMessage.Topic] = retainedMessage; | |||
_logger.Error(exception, "Unhandled exception while loading retained messages."); | |||
} | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Unhandled exception while loading retained messages."); | |||
} | |||
finally | |||
{ | |||
_messagesLock.Exit(); | |||
} | |||
} | |||
public async Task HandleMessageAsync(string clientId, MqttApplicationMessage applicationMessage) | |||
{ | |||
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); | |||
await _messagesLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
{ | |||
await HandleMessageInternalAsync(clientId, applicationMessage); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Unhandled exception while handling retained messages."); | |||
} | |||
finally | |||
using (var releaser = await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
_messagesLock.Exit(); | |||
try | |||
{ | |||
await HandleMessageInternalAsync(clientId, applicationMessage); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Unhandled exception while handling retained messages."); | |||
} | |||
} | |||
} | |||
@@ -73,8 +69,7 @@ namespace MQTTnet.Server | |||
{ | |||
var retainedMessages = new List<MqttApplicationMessage>(); | |||
await _messagesLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); | |||
try | |||
using (var releaser = await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
{ | |||
foreach (var retainedMessage in _messages.Values) | |||
{ | |||
@@ -90,17 +85,12 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
} | |||
finally | |||
{ | |||
_messagesLock.Exit(); | |||
} | |||
return retainedMessages; | |||
} | |||
public void Dispose() | |||
{ | |||
_messagesLock?.Dispose(); | |||
} | |||
private async Task HandleMessageInternalAsync(string clientId, MqttApplicationMessage applicationMessage) | |||
@@ -1,34 +1,210 @@ | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Internal; | |||
using System; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
[TestClass] | |||
// Inspired from the vs-threading tests (https://github.com/Microsoft/vs-threading/blob/master/src/Microsoft.VisualStudio.Threading.Tests/AsyncAutoResetEventTests.cs) | |||
public class AsyncAutoResetEventTests | |||
{ | |||
private AsyncAutoResetEvent evt; | |||
public AsyncAutoResetEventTests() | |||
{ | |||
this.evt = new AsyncAutoResetEvent(); | |||
} | |||
[TestMethod] | |||
public async Task SingleThreadedPulse() | |||
{ | |||
for (int i = 0; i < 5; i++) | |||
{ | |||
var t = this.evt.WaitOneAsync(); | |||
Assert.IsFalse(t.IsCompleted); | |||
this.evt.Set(); | |||
await t; | |||
Assert.IsTrue(t.IsCompleted); | |||
} | |||
} | |||
[TestMethod] | |||
public async Task MultipleSetOnlySignalsOnce() | |||
{ | |||
this.evt.Set(); | |||
this.evt.Set(); | |||
await this.evt.WaitOneAsync(); | |||
var t = this.evt.WaitOneAsync(); | |||
Assert.IsFalse(t.IsCompleted); | |||
await Task.Delay(500); | |||
Assert.IsFalse(t.IsCompleted); | |||
this.evt.Set(); | |||
await t; | |||
Assert.IsTrue(t.IsCompleted); | |||
} | |||
[TestMethod] | |||
public async Task OrderPreservingQueue() | |||
{ | |||
var waiters = new Task[5]; | |||
for (int i = 0; i < waiters.Length; i++) | |||
{ | |||
waiters[i] = this.evt.WaitOneAsync(); | |||
} | |||
for (int i = 0; i < waiters.Length; i++) | |||
{ | |||
this.evt.Set(); | |||
await waiters[i]; | |||
} | |||
} | |||
/// <summary> | |||
/// Verifies that inlining continuations do not have to complete execution before Set() returns. | |||
/// </summary> | |||
[TestMethod] | |||
public async Task SetReturnsBeforeInlinedContinuations() | |||
{ | |||
var setReturned = new ManualResetEventSlim(); | |||
var inlinedContinuation = this.evt.WaitOneAsync() | |||
.ContinueWith(delegate | |||
{ | |||
// Arrange to synchronously block the continuation until Set() has returned, | |||
// which would deadlock if Set does not return until inlined continuations complete. | |||
Assert.IsTrue(setReturned.Wait(500)); | |||
}); | |||
await Task.Delay(100); | |||
this.evt.Set(); | |||
setReturned.Set(); | |||
Assert.IsTrue(inlinedContinuation.Wait(500)); | |||
} | |||
[TestMethod] | |||
public void WaitAsync_WithCancellationToken() | |||
{ | |||
var cts = new CancellationTokenSource(); | |||
Task waitTask = this.evt.WaitOneAsync(cts.Token); | |||
Assert.IsFalse(waitTask.IsCompleted); | |||
// Cancel the request and ensure that it propagates to the task. | |||
cts.Cancel(); | |||
try | |||
{ | |||
waitTask.GetAwaiter().GetResult(); | |||
Assert.IsTrue(false, "Task was expected to transition to a canceled state."); | |||
} | |||
catch (System.OperationCanceledException ex) | |||
{ | |||
Assert.AreEqual(cts.Token, ex.CancellationToken); | |||
} | |||
// Now set the event and verify that a future waiter gets the signal immediately. | |||
this.evt.Set(); | |||
waitTask = this.evt.WaitOneAsync(); | |||
Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); | |||
} | |||
[TestMethod] | |||
public void WaitAsync_WithCancellationToken_Precanceled() | |||
{ | |||
// We construct our own pre-canceled token so that we can do | |||
// a meaningful identity check later. | |||
var tokenSource = new CancellationTokenSource(); | |||
tokenSource.Cancel(); | |||
var token = tokenSource.Token; | |||
// Verify that a pre-set signal is not reset by a canceled wait request. | |||
this.evt.Set(); | |||
try | |||
{ | |||
this.evt.WaitOneAsync(token).GetAwaiter().GetResult(); | |||
Assert.IsTrue(false, "Task was expected to transition to a canceled state."); | |||
} | |||
catch (OperationCanceledException ex) | |||
{ | |||
Assert.AreEqual(token, ex.CancellationToken); | |||
} | |||
// Verify that the signal was not acquired. | |||
Task waitTask = this.evt.WaitOneAsync(); | |||
Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); | |||
} | |||
[TestMethod] | |||
public async Task WaitAsync_WithTimeout() | |||
{ | |||
Task waitTask = this.evt.WaitOneAsync(TimeSpan.FromMilliseconds(500)); | |||
Assert.IsFalse(waitTask.IsCompleted); | |||
// Cancel the request and ensure that it propagates to the task. | |||
await Task.Delay(1000); | |||
try | |||
{ | |||
waitTask.GetAwaiter().GetResult(); | |||
Assert.IsTrue(false, "Task was expected to transition to a timeout state."); | |||
} | |||
catch (System.TimeoutException) | |||
{ | |||
Assert.IsTrue(true); | |||
} | |||
// Now set the event and verify that a future waiter gets the signal immediately. | |||
this.evt.Set(); | |||
waitTask = this.evt.WaitOneAsync(TimeSpan.FromMilliseconds(500)); | |||
Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); | |||
} | |||
[TestMethod] | |||
public void WaitAsync_Canceled_DoesNotInlineContinuations() | |||
{ | |||
var cts = new CancellationTokenSource(); | |||
var task = this.evt.WaitOneAsync(cts.Token); | |||
var completingActionFinished = new ManualResetEventSlim(); | |||
var continuation = task.ContinueWith( | |||
_ => Assert.IsTrue(completingActionFinished.Wait(500)), | |||
CancellationToken.None, | |||
TaskContinuationOptions.None, | |||
TaskScheduler.Default); | |||
cts.Cancel(); | |||
completingActionFinished.Set(); | |||
// Rethrow the exception if it turned out it deadlocked. | |||
continuation.GetAwaiter().GetResult(); | |||
} | |||
[TestMethod] | |||
public async Task AsyncAutoResetEvent() | |||
{ | |||
var aare = new AsyncAutoResetEvent(); | |||
var increment = 0; | |||
var globalI = 0; | |||
#pragma warning disable 4014 | |||
Task.Run(async () => | |||
#pragma warning restore 4014 | |||
{ | |||
await aare.WaitOneAsync(CancellationToken.None); | |||
globalI += increment; | |||
globalI += 1; | |||
}); | |||
#pragma warning disable 4014 | |||
Task.Run(async () => | |||
#pragma warning restore 4014 | |||
{ | |||
await aare.WaitOneAsync(CancellationToken.None); | |||
globalI += 2; | |||
}); | |||
await Task.Delay(500); | |||
aare.Set(); | |||
await Task.Delay(500); | |||
increment = 1; | |||
aare.Set(); | |||
await Task.Delay(100); | |||
Assert.AreEqual(1, globalI); | |||
Assert.AreEqual(3, globalI); | |||
} | |||
} | |||
} | |||
} |
@@ -21,22 +21,17 @@ namespace MQTTnet.Core.Tests | |||
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed | |||
threads[i] = Task.Run(async () => | |||
{ | |||
await @lock.EnterAsync(CancellationToken.None); | |||
try | |||
using (var releaser = await @lock.LockAsync(CancellationToken.None)) | |||
{ | |||
var localI = globalI; | |||
await Task.Delay(10); // Increase the chance for wrong data. | |||
localI++; | |||
globalI = localI; | |||
} | |||
finally | |||
{ | |||
@lock.Exit(); | |||
} | |||
}); | |||
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed | |||
} | |||
Task.WaitAll(threads); | |||
Assert.AreEqual(ThreadsCount, globalI); | |||
} | |||