diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs index 2e04c90..8384bca 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs @@ -290,18 +290,27 @@ namespace MQTTnet.Client private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - _sendTracker.Restart(); - return _adapter.SendPacketsAsync(_options.CommunicationTimeout, new[] { packet }, cancellationToken); + return SendAsync(new[] { packet }, cancellationToken); } private Task SendAsync(IEnumerable packets, CancellationToken cancellationToken) { + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); + } + _sendTracker.Restart(); return _adapter.SendPacketsAsync(_options.CommunicationTimeout, packets, cancellationToken); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket { + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); + } + _sendTracker.Restart(); ushort identifier = 0; @@ -528,7 +537,7 @@ namespace MQTTnet.Client private static async Task WaitForTaskAsync(Task task, Task sender) { - if (task == sender) + if (task == sender || task == null) { return; } diff --git a/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs b/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs new file mode 100644 index 0000000..ad20646 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public sealed class AsyncAutoResetEvent : IDisposable + { + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(0, 1); + + public Task WaitOneAsync(CancellationToken cancellationToken) + { + return _semaphore.WaitAsync(cancellationToken); + } + + public void Set() + { + _semaphore.Release(); + } + + public void Dispose() + { + _semaphore?.Dispose(); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs b/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs new file mode 100644 index 0000000..145e385 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public sealed class AsyncLock : IDisposable + { + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + + public Task EnterAsync(CancellationToken cancellationToken) + { + return _semaphore.WaitAsync(cancellationToken); + } + + public void Exit() + { + _semaphore.Release(); + } + + public void Dispose() + { + _semaphore?.Dispose(); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs index f56a8c9..1d00528 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -13,7 +14,7 @@ namespace MQTTnet.Server public sealed class MqttClientPendingMessagesQueue : IDisposable { private readonly ConcurrentQueue _queue = new ConcurrentQueue(); - private readonly SemaphoreSlim _queueWaitSemaphore = new SemaphoreSlim(0); + private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); private readonly IMqttServerOptions _options; private readonly MqttClientSession _clientSession; private readonly IMqttNetLogger _logger; @@ -54,7 +55,7 @@ namespace MQTTnet.Server if (packet == null) throw new ArgumentNullException(nameof(packet)); _queue.Enqueue(packet); - _queueWaitSemaphore.Release(); + _queueAutoResetEvent.Set(); _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); } @@ -82,7 +83,7 @@ namespace MQTTnet.Server MqttBasePacket packet = null; try { - await _queueWaitSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false); if (!_queue.TryDequeue(out packet)) { throw new InvalidOperationException(); // should not happen @@ -120,8 +121,8 @@ namespace MQTTnet.Server if (publishPacket.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { publishPacket.Dup = true; - _queue.Enqueue(packet); - _queueWaitSemaphore.Release(); + + Enqueue(publishPacket); } } @@ -134,7 +135,7 @@ namespace MQTTnet.Server public void Dispose() { - _queueWaitSemaphore?.Dispose(); + _queueAutoResetEvent?.Dispose(); } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs index 0e61b10..183ab8e 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs @@ -134,7 +134,7 @@ namespace MQTTnet.Server { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - var result = await SubscriptionsManager.CheckSubscriptionsAsync(applicationMessage); + var result = await SubscriptionsManager.CheckSubscriptionsAsync(applicationMessage).ConfigureAwait(false); if (!result.IsSubscribed) { return; @@ -155,10 +155,10 @@ namespace MQTTnet.Server { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - await SubscriptionsManager.SubscribeAsync(new MqttSubscribePacket + SubscriptionsManager.Subscribe(new MqttSubscribePacket { TopicFilters = topicFilters - }).ConfigureAwait(false); + }); await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false); } @@ -167,10 +167,12 @@ namespace MQTTnet.Server { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return SubscriptionsManager.UnsubscribeAsync(new MqttUnsubscribePacket + SubscriptionsManager.Unsubscribe(new MqttUnsubscribePacket { TopicFilters = topicFilters }); + + return Task.FromResult(0); } public void Dispose() @@ -288,7 +290,7 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { - var subscribeResult = await SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); + var subscribeResult = SubscriptionsManager.Subscribe(subscribePacket); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { subscribeResult.ResponsePacket }, cancellationToken).ConfigureAwait(false); if (subscribeResult.CloseConnection) @@ -302,7 +304,7 @@ namespace MQTTnet.Server private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { - var unsubscribeResult = await SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); + var unsubscribeResult = SubscriptionsManager.Unsubscribe(unsubscribePacket); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { unsubscribeResult }, cancellationToken); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs index a49723a..dfb9463 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading; @@ -10,7 +11,7 @@ namespace MQTTnet.Server { public sealed class MqttClientSubscriptionsManager : IDisposable { - private readonly Dictionary _subscriptions = new Dictionary(); + private readonly ConcurrentDictionary _subscriptions = new ConcurrentDictionary(); private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly IMqttServerOptions _options; private readonly MqttServer _server; @@ -23,7 +24,7 @@ namespace MQTTnet.Server _server = server; } - public async Task SubscribeAsync(MqttSubscribePacket subscribePacket) + public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); @@ -37,57 +38,41 @@ namespace MQTTnet.Server CloseConnection = false }; - await _semaphore.WaitAsync().ConfigureAwait(false); - try + foreach (var topicFilter in subscribePacket.TopicFilters) { - foreach (var topicFilter in subscribePacket.TopicFilters) + var interceptorContext = InterceptSubscribe(topicFilter); + if (!interceptorContext.AcceptSubscription) { - var interceptorContext = InterceptSubscribe(topicFilter); - if (!interceptorContext.AcceptSubscription) - { - result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure); - } - else - { - result.ResponsePacket.SubscribeReturnCodes.Add(ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel)); - } + result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure); + } + else + { + result.ResponsePacket.SubscribeReturnCodes.Add(ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel)); + } - if (interceptorContext.CloseConnection) - { - result.CloseConnection = true; - } + if (interceptorContext.CloseConnection) + { + result.CloseConnection = true; + } - if (interceptorContext.AcceptSubscription) - { - _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; - _server.OnClientSubscribedTopic(_clientId, topicFilter); - } + if (interceptorContext.AcceptSubscription) + { + _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + _server.OnClientSubscribedTopic(_clientId, topicFilter); } } - finally - { - _semaphore.Release(); - } return result; } - public async Task UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket) + public MqttUnsubAckPacket Unsubscribe(MqttUnsubscribePacket unsubscribePacket) { if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); - await _semaphore.WaitAsync().ConfigureAwait(false); - try + foreach (var topicFilter in unsubscribePacket.TopicFilters) { - foreach (var topicFilter in unsubscribePacket.TopicFilters) - { - _subscriptions.Remove(topicFilter); - _server.OnClientUnsubscribedTopic(_clientId, topicFilter); - } - } - finally - { - _semaphore.Release(); + _subscriptions.TryRemove(topicFilter, out _); + _server.OnClientUnsubscribedTopic(_clientId, topicFilter); } return new MqttUnsubAckPacket diff --git a/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs new file mode 100644 index 0000000..77275eb --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs @@ -0,0 +1,34 @@ +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class AsyncAutoResetEventTests + { + [TestMethod] + public async Task AsyncAutoResetEvent() + { + var aare = new AsyncAutoResetEvent(); + + var increment = 0; + var globalI = 0; +#pragma warning disable 4014 + Task.Run(async () => +#pragma warning restore 4014 + { + await aare.WaitOneAsync(CancellationToken.None); + globalI += increment; + }); + + await Task.Delay(500); + increment = 1; + aare.Set(); + await Task.Delay(100); + + Assert.AreEqual(1, globalI); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs new file mode 100644 index 0000000..56c7050 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs @@ -0,0 +1,44 @@ +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class AsyncLockTests + { + [TestMethod] + public void AsyncLock() + { + const int ThreadsCount = 10; + + var threads = new Task[ThreadsCount]; + var @lock = new AsyncLock(); + var globalI = 0; + for (var i = 0; i < ThreadsCount; i++) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + threads[i] = Task.Run(async () => + { + await @lock.EnterAsync(CancellationToken.None); + try + { + var localI = globalI; + await Task.Delay(10); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + finally + { + @lock.Exit(); + } + }); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + Task.WaitAll(threads); + Assert.AreEqual(ThreadsCount, globalI); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 399e71e..bcd9fb8 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Diagnostics; @@ -58,8 +59,8 @@ namespace MQTTnet.Core.Tests await s.StartAsync(new MqttServerOptions()); var willMessage = new MqttApplicationMessageBuilder().WithTopic("My/last/will").WithAtMostOnceQoS().Build(); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2", willMessage); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2", willMessage); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic("#").Build()); @@ -90,8 +91,8 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); @@ -149,7 +150,7 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; @@ -167,6 +168,40 @@ namespace MQTTnet.Core.Tests Assert.AreEqual(1, receivedMessagesCount); } + [TestMethod] + public async Task MqttServer_RetainedMessagesFlow() + { + var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); + var serverAdapter = new TestMqttServerAdapter(); + var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); + await s.StartAsync(new MqttServerOptions()); + var c1 = await serverAdapter.ConnectTestClient("c1"); + await c1.PublishAsync(retainedMessage); + Thread.Sleep(500); + await c1.DisconnectAsync(); + Thread.Sleep(500); + + var receivedMessages = 0; + var c2 = await serverAdapter.ConnectTestClient("c2"); + c2.ApplicationMessageReceived += (_, e) => + { + receivedMessages++; + }; + + for (var i = 0; i < 5; i++) + { + await c2.UnsubscribeAsync("r"); + await Task.Delay(500); + Assert.AreEqual(i, receivedMessages); + + await c2.SubscribeAsync("r"); + await Task.Delay(500); + Assert.AreEqual(i + 1, receivedMessages); + } + + await c2.DisconnectAsync(); + } + [TestMethod] public async Task MqttServer_NoRetainedMessage() { @@ -179,11 +214,11 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).Build()); await c1.DisconnectAsync(); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); @@ -208,11 +243,11 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.DisconnectAsync(); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); @@ -237,14 +272,14 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[0]).WithRetainFlag().Build()); await c1.DisconnectAsync(); - - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - + await Task.Delay(200); await c2.SubscribeAsync(new TopicFilter("retained", MqttQualityOfServiceLevel.AtMostOnce)); await Task.Delay(500); @@ -271,7 +306,7 @@ namespace MQTTnet.Core.Tests await s.StartAsync(options); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.DisconnectAsync(); @@ -291,7 +326,7 @@ namespace MQTTnet.Core.Tests var options = new MqttServerOptions { Storage = storage }; await s.StartAsync(options); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); @@ -322,8 +357,8 @@ namespace MQTTnet.Core.Tests await s.StartAsync(options); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("test").Build()); var isIntercepted = false; @@ -357,8 +392,8 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c1.ApplicationMessageReceived += (_, e) => { @@ -412,8 +447,8 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index 1346415..6f16ab5 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -18,7 +18,7 @@ namespace MQTTnet.Core.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -39,7 +39,7 @@ namespace MQTTnet.Core.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -61,7 +61,7 @@ namespace MQTTnet.Core.Tests sp.TopicFilters.Add(new TopicFilter("#", MqttQualityOfServiceLevel.AtMostOnce)); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtLeastOnce)); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -82,7 +82,7 @@ namespace MQTTnet.Core.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -101,7 +101,7 @@ namespace MQTTnet.Core.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -113,7 +113,7 @@ namespace MQTTnet.Core.Tests var up = new MqttUnsubscribePacket(); up.TopicFilters.Add("A/B/C"); - sm.UnsubscribeAsync(up).Wait(); + sm.Unsubscribe(up); Assert.IsFalse(sm.CheckSubscriptionsAsync(pp).Result.IsSubscribed); } diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index 808e08a..837221b 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -43,11 +43,30 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) + public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); - return Task.Run(() => + if (timeout > TimeSpan.Zero) + { + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) + { + return await Task.Run(() => + { + try + { + return _incomingPackets.Take(cts.Token); + } + catch + { + return null; + } + }, cts.Token); + } + } + + return await Task.Run(() => { try { diff --git a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs index 8b9d0ba..2b24f42 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Client; @@ -11,7 +12,7 @@ namespace MQTTnet.Core.Tests { public event EventHandler ClientAccepted; - public async Task ConnectTestClient(IMqttServer server, string clientId, MqttApplicationMessage willMessage = null) + public async Task ConnectTestClient(string clientId, MqttApplicationMessage willMessage = null) { var adapterA = new TestMqttCommunicationAdapter(); var adapterB = new TestMqttCommunicationAdapter(); @@ -22,8 +23,6 @@ namespace MQTTnet.Core.Tests new TestMqttCommunicationAdapterFactory(adapterA), new MqttNetLogger()); - var connected = WaitForClientToConnect(server, clientId); - FireClientAcceptedEvent(adapterB); var options = new MqttClientOptions @@ -34,29 +33,11 @@ namespace MQTTnet.Core.Tests }; await client.ConnectAsync(options); - await connected; + SpinWait.SpinUntil(() => client.IsConnected); return client; } - private static Task WaitForClientToConnect(IMqttServer s, string clientId) - { - var tcs = new TaskCompletionSource(); - - void Handler(object sender, Server.MqttClientConnectedEventArgs args) - { - if (args.Client.ClientId == clientId) - { - s.ClientConnected -= Handler; - tcs.SetResult(null); - } - } - - s.ClientConnected += Handler; - - return tcs.Task; - } - private void FireClientAcceptedEvent(IMqttChannelAdapter adapter) { ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(adapter));