diff --git a/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs b/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs index ad20646..0a304e6 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs @@ -1,26 +1,109 @@ using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; namespace MQTTnet.Internal { - public sealed class AsyncAutoResetEvent : IDisposable + // Inspired from Stephen Toub (https://blogs.msdn.microsoft.com/pfxteam/2012/02/11/building-async-coordination-primitives-part-2-asyncautoresetevent/) and Chris Gillum (https://stackoverflow.com/a/43012490) + public class AsyncAutoResetEvent { - private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(0, 1); + private readonly LinkedList> waiters = new LinkedList>(); - public Task WaitOneAsync(CancellationToken cancellationToken) + private bool isSignaled; + + public AsyncAutoResetEvent(bool signaled = false) { - return _semaphore.WaitAsync(cancellationToken); + this.isSignaled = signaled; } - public void Set() + public Task WaitOneAsync() + { + return this.WaitOneAsync(CancellationToken.None); + } + + public Task WaitOneAsync(TimeSpan timeout) { - _semaphore.Release(); + return this.WaitOneAsync(timeout, CancellationToken.None); } - public void Dispose() + public Task WaitOneAsync(CancellationToken cancellationToken) { - _semaphore?.Dispose(); + return this.WaitOneAsync(Timeout.InfiniteTimeSpan, cancellationToken); + } + + public async Task WaitOneAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + TaskCompletionSource tcs; + + lock (this.waiters) + { + if (this.isSignaled) + { + this.isSignaled = false; + return true; + } + else if (timeout == TimeSpan.Zero) + { + return this.isSignaled; + } + else + { + tcs = new TaskCompletionSource(); + this.waiters.AddLast(tcs); + } + } + + Task winner = await Task.WhenAny(tcs.Task, Task.Delay(timeout, cancellationToken)); + if (winner == tcs.Task) + { + // The task was signaled. + return true; + } + else + { + // We timed-out; remove our reference to the task. + // This is an O(n) operation since waiters is a LinkedList. + lock (this.waiters) + { + bool removed = this.waiters.Remove(tcs); + if (winner.Status == TaskStatus.Canceled) + { + throw new OperationCanceledException(cancellationToken); + } + else + { + throw new TimeoutException(); + } + } + } + } + + public void Set() + { + TaskCompletionSource toRelease = null; + + lock (this.waiters) + { + if (this.waiters.Count > 0) + { + // Signal the first task in the waiters list. + toRelease = this.waiters.First.Value; + this.waiters.RemoveFirst(); + } + else if (!this.isSignaled) + { + // No tasks are pending + this.isSignaled = true; + } + } + + if (toRelease != null) + { + toRelease.SetResult(true); + } } } } diff --git a/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs b/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs index 145e385..d402949 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs @@ -4,23 +4,32 @@ using System.Threading.Tasks; namespace MQTTnet.Internal { - public sealed class AsyncLock : IDisposable + // From Stephen Toub (https://blogs.msdn.microsoft.com/pfxteam/2012/02/12/building-async-coordination-primitives-part-6-asynclock/) + public sealed class AsyncLock { - private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + private readonly SemaphoreSlim m_semaphore = new SemaphoreSlim(1, 1); + private readonly Task m_releaser; - public Task EnterAsync(CancellationToken cancellationToken) + public AsyncLock() { - return _semaphore.WaitAsync(cancellationToken); + m_releaser = Task.FromResult((IDisposable)new Releaser(this)); } - public void Exit() + public Task LockAsync(CancellationToken cancellationToken) { - _semaphore.Release(); + var wait = m_semaphore.WaitAsync(cancellationToken); + return wait.IsCompleted ? + m_releaser : + wait.ContinueWith((_, state) => (IDisposable)state, + m_releaser.Result, cancellationToken, + TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); } - public void Dispose() + private sealed class Releaser : IDisposable { - _semaphore?.Dispose(); + private readonly AsyncLock m_toRelease; + internal Releaser(AsyncLock toRelease) { m_toRelease = toRelease; } + public void Dispose() { m_toRelease.m_semaphore.Release(); } } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs index fc74d4c..f08e18e 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs @@ -154,7 +154,6 @@ namespace MQTTnet.Server public void Dispose() { - _queueAutoResetEvent?.Dispose(); } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs index e50b8c2..e47f3f3 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs @@ -119,8 +119,7 @@ namespace MQTTnet.Server public async Task StopAsync() { - await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { foreach (var session in _sessions) { @@ -129,16 +128,11 @@ namespace MQTTnet.Server _sessions.Clear(); } - finally - { - _sessionsLock.Exit(); - } } public async Task> GetConnectedClientsAsync() { - await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { return _sessions.Where(s => s.Value.IsConnected).Select(s => new ConnectedMqttClient { @@ -149,10 +143,6 @@ namespace MQTTnet.Server PendingApplicationMessages = s.Value.PendingMessagesQueue.Count }).ToList(); } - finally - { - _sessionsLock.Exit(); - } } public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) @@ -165,8 +155,7 @@ namespace MQTTnet.Server if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { if (!_sessions.TryGetValue(clientId, out var session)) { @@ -175,10 +164,6 @@ namespace MQTTnet.Server await session.SubscribeAsync(topicFilters).ConfigureAwait(false); } - finally - { - _sessionsLock.Exit(); - } } public async Task UnsubscribeAsync(string clientId, IList topicFilters) @@ -186,8 +171,7 @@ namespace MQTTnet.Server if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { if (!_sessions.TryGetValue(clientId, out var session)) { @@ -196,15 +180,10 @@ namespace MQTTnet.Server await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false); } - finally - { - _sessionsLock.Exit(); - } } public void Dispose() { - _sessionsLock?.Dispose(); } private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) @@ -226,8 +205,8 @@ namespace MQTTnet.Server private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) { - await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) @@ -261,10 +240,6 @@ namespace MQTTnet.Server return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; } - finally - { - _sessionsLock.Exit(); - } } private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) @@ -294,18 +269,13 @@ namespace MQTTnet.Server _logger.Error(exception, "Error while processing application message"); } - await _sessionsLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _sessionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { foreach (var clientSession in _sessions.Values) { await clientSession.EnqueueApplicationMessageAsync(applicationMessage).ConfigureAwait(false); } } - finally - { - _sessionsLock.Exit(); - } } private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttRetainedMessagesManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttRetainedMessagesManager.cs index be0a5b7..e00ff44 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttRetainedMessagesManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttRetainedMessagesManager.cs @@ -29,43 +29,39 @@ namespace MQTTnet.Server return; } - await _messagesLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { - var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); + try + { + var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); - _messages.Clear(); - foreach (var retainedMessage in retainedMessages) + _messages.Clear(); + foreach (var retainedMessage in retainedMessages) + { + _messages[retainedMessage.Topic] = retainedMessage; + } + } + catch (Exception exception) { - _messages[retainedMessage.Topic] = retainedMessage; + _logger.Error(exception, "Unhandled exception while loading retained messages."); } } - catch (Exception exception) - { - _logger.Error(exception, "Unhandled exception while loading retained messages."); - } - finally - { - _messagesLock.Exit(); - } } public async Task HandleMessageAsync(string clientId, MqttApplicationMessage applicationMessage) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - await _messagesLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try - { - await HandleMessageInternalAsync(clientId, applicationMessage); - } - catch (Exception exception) - { - _logger.Error(exception, "Unhandled exception while handling retained messages."); - } - finally + using (var releaser = await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { - _messagesLock.Exit(); + try + { + await HandleMessageInternalAsync(clientId, applicationMessage); + } + catch (Exception exception) + { + _logger.Error(exception, "Unhandled exception while handling retained messages."); + } } } @@ -73,8 +69,7 @@ namespace MQTTnet.Server { var retainedMessages = new List(); - await _messagesLock.EnterAsync(CancellationToken.None).ConfigureAwait(false); - try + using (var releaser = await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { foreach (var retainedMessage in _messages.Values) { @@ -90,17 +85,12 @@ namespace MQTTnet.Server } } } - finally - { - _messagesLock.Exit(); - } return retainedMessages; } public void Dispose() { - _messagesLock?.Dispose(); } private async Task HandleMessageInternalAsync(string clientId, MqttApplicationMessage applicationMessage) diff --git a/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs index 77275eb..879f925 100644 --- a/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs +++ b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs @@ -1,34 +1,210 @@ -using System.Threading; -using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Internal; +using System; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet.Core.Tests { [TestClass] + // Inspired from the vs-threading tests (https://github.com/Microsoft/vs-threading/blob/master/src/Microsoft.VisualStudio.Threading.Tests/AsyncAutoResetEventTests.cs) public class AsyncAutoResetEventTests { + private AsyncAutoResetEvent evt; + + public AsyncAutoResetEventTests() + { + this.evt = new AsyncAutoResetEvent(); + } + + [TestMethod] + public async Task SingleThreadedPulse() + { + for (int i = 0; i < 5; i++) + { + var t = this.evt.WaitOneAsync(); + Assert.IsFalse(t.IsCompleted); + this.evt.Set(); + await t; + Assert.IsTrue(t.IsCompleted); + } + } + + [TestMethod] + public async Task MultipleSetOnlySignalsOnce() + { + this.evt.Set(); + this.evt.Set(); + await this.evt.WaitOneAsync(); + var t = this.evt.WaitOneAsync(); + Assert.IsFalse(t.IsCompleted); + await Task.Delay(500); + Assert.IsFalse(t.IsCompleted); + this.evt.Set(); + await t; + Assert.IsTrue(t.IsCompleted); + } + + [TestMethod] + public async Task OrderPreservingQueue() + { + var waiters = new Task[5]; + for (int i = 0; i < waiters.Length; i++) + { + waiters[i] = this.evt.WaitOneAsync(); + } + + for (int i = 0; i < waiters.Length; i++) + { + this.evt.Set(); + await waiters[i]; + } + } + + /// + /// Verifies that inlining continuations do not have to complete execution before Set() returns. + /// + [TestMethod] + public async Task SetReturnsBeforeInlinedContinuations() + { + var setReturned = new ManualResetEventSlim(); + var inlinedContinuation = this.evt.WaitOneAsync() + .ContinueWith(delegate + { + // Arrange to synchronously block the continuation until Set() has returned, + // which would deadlock if Set does not return until inlined continuations complete. + Assert.IsTrue(setReturned.Wait(500)); + }); + await Task.Delay(100); + this.evt.Set(); + setReturned.Set(); + Assert.IsTrue(inlinedContinuation.Wait(500)); + } + + [TestMethod] + public void WaitAsync_WithCancellationToken() + { + var cts = new CancellationTokenSource(); + Task waitTask = this.evt.WaitOneAsync(cts.Token); + Assert.IsFalse(waitTask.IsCompleted); + + // Cancel the request and ensure that it propagates to the task. + cts.Cancel(); + try + { + waitTask.GetAwaiter().GetResult(); + Assert.IsTrue(false, "Task was expected to transition to a canceled state."); + } + catch (System.OperationCanceledException ex) + { + Assert.AreEqual(cts.Token, ex.CancellationToken); + } + + // Now set the event and verify that a future waiter gets the signal immediately. + this.evt.Set(); + waitTask = this.evt.WaitOneAsync(); + Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); + } + + [TestMethod] + public void WaitAsync_WithCancellationToken_Precanceled() + { + // We construct our own pre-canceled token so that we can do + // a meaningful identity check later. + var tokenSource = new CancellationTokenSource(); + tokenSource.Cancel(); + var token = tokenSource.Token; + + // Verify that a pre-set signal is not reset by a canceled wait request. + this.evt.Set(); + try + { + this.evt.WaitOneAsync(token).GetAwaiter().GetResult(); + Assert.IsTrue(false, "Task was expected to transition to a canceled state."); + } + catch (OperationCanceledException ex) + { + Assert.AreEqual(token, ex.CancellationToken); + } + + // Verify that the signal was not acquired. + Task waitTask = this.evt.WaitOneAsync(); + Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); + } + + [TestMethod] + public async Task WaitAsync_WithTimeout() + { + Task waitTask = this.evt.WaitOneAsync(TimeSpan.FromMilliseconds(500)); + Assert.IsFalse(waitTask.IsCompleted); + + // Cancel the request and ensure that it propagates to the task. + await Task.Delay(1000); + try + { + waitTask.GetAwaiter().GetResult(); + Assert.IsTrue(false, "Task was expected to transition to a timeout state."); + } + catch (System.TimeoutException) + { + Assert.IsTrue(true); + } + + // Now set the event and verify that a future waiter gets the signal immediately. + this.evt.Set(); + waitTask = this.evt.WaitOneAsync(TimeSpan.FromMilliseconds(500)); + Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); + } + + [TestMethod] + public void WaitAsync_Canceled_DoesNotInlineContinuations() + { + var cts = new CancellationTokenSource(); + var task = this.evt.WaitOneAsync(cts.Token); + + var completingActionFinished = new ManualResetEventSlim(); + var continuation = task.ContinueWith( + _ => Assert.IsTrue(completingActionFinished.Wait(500)), + CancellationToken.None, + TaskContinuationOptions.None, + TaskScheduler.Default); + + cts.Cancel(); + completingActionFinished.Set(); + + // Rethrow the exception if it turned out it deadlocked. + continuation.GetAwaiter().GetResult(); + } + [TestMethod] public async Task AsyncAutoResetEvent() { var aare = new AsyncAutoResetEvent(); - var increment = 0; var globalI = 0; #pragma warning disable 4014 Task.Run(async () => #pragma warning restore 4014 { await aare.WaitOneAsync(CancellationToken.None); - globalI += increment; + globalI += 1; }); +#pragma warning disable 4014 + Task.Run(async () => +#pragma warning restore 4014 + { + await aare.WaitOneAsync(CancellationToken.None); + globalI += 2; + }); + + await Task.Delay(500); + aare.Set(); await Task.Delay(500); - increment = 1; aare.Set(); await Task.Delay(100); - Assert.AreEqual(1, globalI); + Assert.AreEqual(3, globalI); } } -} +} \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs index 56c7050..43f2486 100644 --- a/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs +++ b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs @@ -21,22 +21,17 @@ namespace MQTTnet.Core.Tests #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed threads[i] = Task.Run(async () => { - await @lock.EnterAsync(CancellationToken.None); - try + using (var releaser = await @lock.LockAsync(CancellationToken.None)) { var localI = globalI; await Task.Delay(10); // Increase the chance for wrong data. localI++; globalI = localI; } - finally - { - @lock.Exit(); - } }); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } - + Task.WaitAll(threads); Assert.AreEqual(ThreadsCount, globalI); }