diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs index add5f76..9dc1ba4 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs @@ -19,12 +19,22 @@ namespace MQTTnet.Extensions.ManagedClient public class ManagedMqttClient : IManagedMqttClient { private readonly BlockingQueue _messageQueue = new BlockingQueue(); + + /// + /// The subscriptions are managed in 2 separate buckets: + /// and are processed during normal operation + /// and are moved to the when they get processed. They can be accessed by + /// any thread and are therefore mutex'ed. get sent to the broker + /// at reconnect and are solely owned by . + /// + private readonly Dictionary _reconnectSubscriptions = new Dictionary(); private readonly Dictionary _subscriptions = new Dictionary(); private readonly HashSet _unsubscriptions = new HashSet(); + private readonly SemaphoreSlim _subscriptionsQueuedSignal = new SemaphoreSlim(0); private readonly IMqttClient _mqttClient; private readonly IMqttNetChildLogger _logger; - + private readonly AsyncLock _messageQueueLock = new AsyncLock(); private CancellationTokenSource _connectionCancellationToken; @@ -34,7 +44,6 @@ namespace MQTTnet.Extensions.ManagedClient private ManagedMqttClientStorageManager _storageManager; private bool _disposed; - private bool _subscriptionsNotPushed; public ManagedMqttClient(IMqttClient mqttClient, IMqttNetChildLogger logger) { @@ -169,7 +178,7 @@ namespace MQTTnet.Extensions.ManagedClient } _messageQueue.Enqueue(applicationMessage); - + if (_storageManager != null) { if (removedMessage != null) @@ -206,9 +215,10 @@ namespace MQTTnet.Extensions.ManagedClient foreach (var topicFilter in topicFilters) { _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; - _subscriptionsNotPushed = true; + _unsubscriptions.Remove(topicFilter.Topic); } } + _subscriptionsQueuedSignal.Release(); return Task.FromResult(0); } @@ -223,13 +233,11 @@ namespace MQTTnet.Extensions.ManagedClient { foreach (var topic in topics) { - if (_subscriptions.Remove(topic)) - { - _unsubscriptions.Add(topic); - _subscriptionsNotPushed = true; - } + _subscriptions.Remove(topic); + _unsubscriptions.Add(topic); } } + _subscriptionsQueuedSignal.Release(); return Task.FromResult(0); } @@ -255,6 +263,7 @@ namespace MQTTnet.Extensions.ManagedClient _messageQueue.Dispose(); _messageQueueLock.Dispose(); _mqttClient.Dispose(); + _subscriptionsQueuedSignal.Dispose(); } private void ThrowIfDisposed() @@ -296,6 +305,12 @@ namespace MQTTnet.Extensions.ManagedClient _logger.Info("Stopped"); } + _reconnectSubscriptions.Clear(); + lock (_subscriptions) + { + _subscriptions.Clear(); + _unsubscriptions.Clear(); + } } } @@ -311,16 +326,16 @@ namespace MQTTnet.Extensions.ManagedClient return; } - if (connectionState == ReconnectionResult.Reconnected || _subscriptionsNotPushed) + if (connectionState == ReconnectionResult.Reconnected) { - await SynchronizeSubscriptionsAsync().ConfigureAwait(false); + await PublishReconnectSubscriptionsAsync().ConfigureAwait(false); StartPublishing(); return; } if (connectionState == ReconnectionResult.StillConnected) { - await Task.Delay(Options.ConnectionCheckInterval, cancellationToken).ConfigureAwait(false); + await PublishSubscriptionsAsync(Options.ConnectionCheckInterval, cancellationToken).ConfigureAwait(false); } } catch (OperationCanceledException) @@ -390,7 +405,7 @@ namespace MQTTnet.Extensions.ManagedClient // it from the queue. If not, that means this.PublishAsync has already // removed it, in which case we don't want to do anything. _messageQueue.RemoveFirst(i => i.Id.Equals(message.Id)); - + if (_storageManager != null) { await _storageManager.RemoveAsync(message).ConfigureAwait(false); @@ -415,7 +430,7 @@ namespace MQTTnet.Extensions.ManagedClient using (await _messageQueueLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) //lock to avoid conflict with this.PublishAsync { _messageQueue.RemoveFirst(i => i.Id.Equals(message.Id)); - + if (_storageManager != null) { await _storageManager.RemoveAsync(message).ConfigureAwait(false); @@ -439,50 +454,84 @@ namespace MQTTnet.Extensions.ManagedClient } } - private async Task SynchronizeSubscriptionsAsync() + private async Task PublishSubscriptionsAsync(TimeSpan timeout, CancellationToken cancellationToken) { - _logger.Info("Synchronizing subscriptions"); + var endTime = DateTime.UtcNow + timeout; + while (await _subscriptionsQueuedSignal.WaitAsync(GetRemainingTime(endTime), cancellationToken).ConfigureAwait(false)) + { + List subscriptions; + HashSet unsubscriptions; - List subscriptions; - HashSet unsubscriptions; + lock (_subscriptions) + { + subscriptions = _subscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }).ToList(); + _subscriptions.Clear(); + unsubscriptions = new HashSet(_unsubscriptions); + _unsubscriptions.Clear(); + } - lock (_subscriptions) - { - subscriptions = _subscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }).ToList(); + if (!subscriptions.Any() && !unsubscriptions.Any()) + { + continue; + } - unsubscriptions = new HashSet(_unsubscriptions); - _unsubscriptions.Clear(); + _logger.Info("Publishing subscriptions"); - _subscriptionsNotPushed = false; - } + foreach (var unsubscription in unsubscriptions) + { + _reconnectSubscriptions.Remove(unsubscription); + } - if (!subscriptions.Any() && !unsubscriptions.Any()) - { - return; - } + foreach (var subscription in subscriptions) + { + _reconnectSubscriptions[subscription.Topic] = subscription.QualityOfServiceLevel; + } - try - { - if (unsubscriptions.Any()) + try { - await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false); + if (unsubscriptions.Any()) + { + await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false); + } + + if (subscriptions.Any()) + { + await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false); + } } + catch (Exception exception) + { + await HandleSubscriptionExceptionAsync(exception).ConfigureAwait(false); + } + } + } - if (subscriptions.Any()) + private async Task PublishReconnectSubscriptionsAsync() + { + _logger.Info("Publishing subscriptions at reconnect"); + + try + { + if (_reconnectSubscriptions.Any()) { + var subscriptions = _reconnectSubscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }); await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false); } } catch (Exception exception) { - _logger.Warning(exception, "Synchronizing subscriptions failed."); - _subscriptionsNotPushed = true; + await HandleSubscriptionExceptionAsync(exception).ConfigureAwait(false); + } + } - var synchronizingSubscriptionsFailedHandler = SynchronizingSubscriptionsFailedHandler; - if (SynchronizingSubscriptionsFailedHandler != null) - { - await synchronizingSubscriptionsFailedHandler.HandleSynchronizingSubscriptionsFailedAsync(new ManagedProcessFailedEventArgs(exception)).ConfigureAwait(false); - } + private async Task HandleSubscriptionExceptionAsync(Exception exception) + { + _logger.Warning(exception, "Synchronizing subscriptions failed."); + + var synchronizingSubscriptionsFailedHandler = SynchronizingSubscriptionsFailedHandler; + if (SynchronizingSubscriptionsFailedHandler != null) + { + await synchronizingSubscriptionsFailedHandler.HandleSynchronizingSubscriptionsFailedAsync(new ManagedProcessFailedEventArgs(exception)).ConfigureAwait(false); } } @@ -509,7 +558,7 @@ namespace MQTTnet.Extensions.ManagedClient return ReconnectionResult.NotConnected; } } - + private void StartPublishing() { if (_publishingCancellationToken != null) @@ -536,5 +585,11 @@ namespace MQTTnet.Extensions.ManagedClient _connectionCancellationToken?.Dispose(); _connectionCancellationToken = null; } + + private TimeSpan GetRemainingTime(DateTime endTime) + { + var remainingTime = endTime - DateTime.UtcNow; + return remainingTime < TimeSpan.Zero ? TimeSpan.Zero : remainingTime; + } } } diff --git a/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs index 0aeea6d..1cf9ab9 100644 --- a/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs @@ -1,10 +1,13 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; using MQTTnet.Client.Connecting; using MQTTnet.Client.Options; +using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; using MQTTnet.Extensions.ManagedClient; using MQTTnet.Server; @@ -95,21 +98,20 @@ namespace MQTTnet.Tests var clientOptions = new MqttClientOptionsBuilder() .WithTcpServer("localhost", testEnvironment.ServerPort); - TaskCompletionSource connected = new TaskCompletionSource(); - managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => { connected.SetResult(true);}); + var connected = GetConnectedTask(managedClient); await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() .WithClientOptions(clientOptions) .Build()); - await connected.Task; + await connected; await managedClient.StopAsync(); Assert.AreEqual(0, (await server.GetClientStatusAsync()).Count); } } - + [TestMethod] public async Task Storage_Queue_Drains() { @@ -127,12 +129,7 @@ namespace MQTTnet.Tests .WithTcpServer("localhost", testEnvironment.ServerPort); var storage = new ManagedMqttClientTestStorage(); - TaskCompletionSource connected = new TaskCompletionSource(); - managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => - { - managedClient.ConnectedHandler = null; - connected.SetResult(true); - }); + var connected = GetConnectedTask(managedClient); await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() .WithClientOptions(clientOptions) @@ -140,7 +137,7 @@ namespace MQTTnet.Tests .WithAutoReconnectDelay(System.TimeSpan.FromSeconds(5)) .Build()); - await connected.Task; + await connected; await testEnvironment.Server.StopAsync(); @@ -151,17 +148,12 @@ namespace MQTTnet.Tests //in storage at this point (i.e. no waiting). Assert.AreEqual(1, storage.GetMessageCount()); - connected = new TaskCompletionSource(); - managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => - { - managedClient.ConnectedHandler = null; - connected.SetResult(true); - }); + connected = GetConnectedTask(managedClient); await testEnvironment.Server.StartAsync(new MqttServerOptionsBuilder() .WithDefaultEndpointPort(testEnvironment.ServerPort).Build()); - await connected.Task; + await connected; //Wait 500ms here so the client has time to publish the queued message await Task.Delay(500); @@ -171,8 +163,235 @@ namespace MQTTnet.Tests await managedClient.StopAsync(); } } + + [TestMethod] + public async Task Subscriptions_And_Unsubscriptions_Are_Made_And_Reestablished_At_Reconnect() + { + using (var testEnvironment = new TestEnvironment()) + { + var unmanagedClient = testEnvironment.CreateClient(); + var managedClient = await CreateManagedClientAsync(testEnvironment, unmanagedClient); + + var received = SetupReceivingOfMessages(managedClient, 2); + + // Perform some opposing subscriptions and unsubscriptions to verify + // that these conflicting subscriptions are handled correctly + await managedClient.SubscribeAsync("keptSubscribed"); + await managedClient.SubscribeAsync("subscribedThenUnsubscribed"); + + await managedClient.UnsubscribeAsync("subscribedThenUnsubscribed"); + await managedClient.UnsubscribeAsync("unsubscribedThenSubscribed"); + + await managedClient.SubscribeAsync("unsubscribedThenSubscribed"); + + //wait a bit for the subscriptions to become established before the messages are published + await Task.Delay(500); + + var sendingClient = await testEnvironment.ConnectClientAsync(); + + async Task PublishMessages() + { + await sendingClient.PublishAsync("keptSubscribed", new byte[] { 1 }); + await sendingClient.PublishAsync("subscribedThenUnsubscribed", new byte[] { 1 }); + await sendingClient.PublishAsync("unsubscribedThenSubscribed", new byte[] { 1 }); + } + + await PublishMessages(); + + async Task AssertMessagesReceived() + { + var messages = await received; + Assert.AreEqual("keptSubscribed", messages[0].Topic); + Assert.AreEqual("unsubscribedThenSubscribed", messages[1].Topic); + } + + await AssertMessagesReceived(); + + var connected = GetConnectedTask(managedClient); + + await unmanagedClient.DisconnectAsync(); + + // the managed client has to reconnect by itself + await connected; + + // wait a bit so that the managed client can reestablish the subscriptions + await Task.Delay(500); + + received = SetupReceivingOfMessages(managedClient, 2); + + await PublishMessages(); + + // and then the same subscriptions need to exist again + await AssertMessagesReceived(); + } + } + + // This case also serves as a regression test for the previous behavior which re-published + // each and every existing subscriptions with every new subscription that was made + // (causing performance problems and having the visible symptom of retained messages being received again) + [TestMethod] + public async Task Subscriptions_Subscribe_Only_New_Subscriptions() + { + using (var testEnvironment = new TestEnvironment()) + { + var managedClient = await CreateManagedClientAsync(testEnvironment); + + var sendingClient = await testEnvironment.ConnectClientAsync(); + + await managedClient.SubscribeAsync("topic"); + + //wait a bit for the subscription to become established + await Task.Delay(500); + + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + + var messages = await SetupReceivingOfMessages(managedClient, 1); + + Assert.AreEqual(1, messages.Count); + Assert.AreEqual("topic", messages.Single().Topic); + + await managedClient.SubscribeAsync("anotherTopic"); + + await Task.Delay(500); + + // The subscription of the other topic must not trigger a re-subscription of the existing topic + // (and thus renewed receiving of the retained message) + Assert.AreEqual(1, messages.Count); + } + } + + // This case also serves as a regression test for the previous behavior + // that subscriptions were only published at the ConnectionCheckInterval + [TestMethod] + public async Task Subscriptions_Are_Published_Immediately() + { + using (var testEnvironment = new TestEnvironment()) + { + // Use a long connection check interval to verify that the subscriptions + // do not depend on the connection check interval anymore + var connectionCheckInterval = TimeSpan.FromSeconds(10); + var managedClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval); + var sendingClient = await testEnvironment.ConnectClientAsync(); + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + + await managedClient.SubscribeAsync("topic"); + + var subscribeTime = DateTime.UtcNow; + + var messages = await SetupReceivingOfMessages(managedClient, 1); + + var elapsed = DateTime.UtcNow - subscribeTime; + Assert.IsTrue(elapsed < TimeSpan.FromSeconds(1), $"Subscriptions must be activated immediately, this one took {elapsed}"); + Assert.AreEqual(messages.Single().Topic, "topic"); + } + } + + [TestMethod] + public async Task Subscriptions_Are_Cleared_At_Logout() + { + using (var testEnvironment = new TestEnvironment()) + { + var managedClient = await CreateManagedClientAsync(testEnvironment); + + var sendingClient = await testEnvironment.ConnectClientAsync(); + await sendingClient.PublishAsync(new MqttApplicationMessage + { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + + // Wait a bit for the retained message to be available + await Task.Delay(500); + + await managedClient.SubscribeAsync("topic"); + + await SetupReceivingOfMessages(managedClient, 1); + + await managedClient.StopAsync(); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost", testEnvironment.ServerPort); + await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() + .WithClientOptions(clientOptions) + .WithAutoReconnectDelay(TimeSpan.FromSeconds(1)) + .Build()); + + var messages = new List(); + managedClient.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(r => + { + messages.Add(r.ApplicationMessage); + }); + + await Task.Delay(500); + + // After reconnect and then some delay, the retained message must not be received, + // showing that the subscriptions were cleared + Assert.AreEqual(0, messages.Count); + } + } + + private async Task CreateManagedClientAsync( + TestEnvironment testEnvironment, + IMqttClient underlyingClient = null, + TimeSpan? connectionCheckInterval = null) + { + await testEnvironment.StartServerAsync(); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost", testEnvironment.ServerPort); + + var managedOptions = new ManagedMqttClientOptionsBuilder() + .WithClientOptions(clientOptions) + .Build(); + + // Use a short connection check interval so that subscription operations are performed quickly + // in order to verify against a previous implementation that performed subscriptions only + // at connection check intervals + managedOptions.ConnectionCheckInterval = connectionCheckInterval ?? TimeSpan.FromSeconds(0.1); + + var managedClient = + new ManagedMqttClient(underlyingClient ?? testEnvironment.CreateClient(), new MqttNetLogger().CreateChildLogger()); + + var connected = GetConnectedTask(managedClient); + + await managedClient.StartAsync(managedOptions); + + await connected; + + return managedClient; + } + + /// + /// Returns a task that will finish when the has connected + /// + private Task GetConnectedTask(ManagedMqttClient managedClient) + { + TaskCompletionSource connected = new TaskCompletionSource(); + managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => + { + managedClient.ConnectedHandler = null; + connected.SetResult(true); + }); + return connected.Task; + } + + /// + /// Returns a task that will return the messages received on + /// when have been received + /// + private Task> SetupReceivingOfMessages(ManagedMqttClient managedClient, int expectedNumberOfMessages) + { + var receivedMessages = new List(); + var allReceived = new TaskCompletionSource>(); + managedClient.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(r => + { + receivedMessages.Add(r.ApplicationMessage); + if (receivedMessages.Count == expectedNumberOfMessages) + { + allReceived.SetResult(receivedMessages); + } + }); + return allReceived.Task; + } } - + public class ManagedMqttClientTestStorage : IManagedMqttClientStorage { private IList _messages = null;