浏览代码

Merge pull request #776 from cstichlberger/branches/fix_managed_client_subscription_issues

Separate current and reconnect subscriptions in managed client
release/3.x.x
Christian 5 年前
committed by GitHub
父节点
当前提交
a495668f2b
找不到此签名对应的密钥 GPG 密钥 ID: 4AEE18F83AFDEB23
共有 2 个文件被更改,包括 336 次插入62 次删除
  1. +97
    -42
      Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs
  2. +239
    -20
      Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs

+ 97
- 42
Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs 查看文件

@@ -19,12 +19,22 @@ namespace MQTTnet.Extensions.ManagedClient
public class ManagedMqttClient : IManagedMqttClient
{
private readonly BlockingQueue<ManagedMqttApplicationMessage> _messageQueue = new BlockingQueue<ManagedMqttApplicationMessage>();

/// <summary>
/// The subscriptions are managed in 2 separate buckets:
/// <see cref="_subscriptions"/> and <see cref="_unsubscriptions"/> are processed during normal operation
/// and are moved to the <see cref="_reconnectSubscriptions"/> when they get processed. They can be accessed by
/// any thread and are therefore mutex'ed. <see cref="_reconnectSubscriptions"/> get sent to the broker
/// at reconnect and are solely owned by <see cref="MaintainConnectionAsync"/>.
/// </summary>
private readonly Dictionary<string, MqttQualityOfServiceLevel> _reconnectSubscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly HashSet<string> _unsubscriptions = new HashSet<string>();
private readonly SemaphoreSlim _subscriptionsQueuedSignal = new SemaphoreSlim(0);

private readonly IMqttClient _mqttClient;
private readonly IMqttNetChildLogger _logger;
private readonly AsyncLock _messageQueueLock = new AsyncLock();

private CancellationTokenSource _connectionCancellationToken;
@@ -34,7 +44,6 @@ namespace MQTTnet.Extensions.ManagedClient
private ManagedMqttClientStorageManager _storageManager;

private bool _disposed;
private bool _subscriptionsNotPushed;

public ManagedMqttClient(IMqttClient mqttClient, IMqttNetChildLogger logger)
{
@@ -169,7 +178,7 @@ namespace MQTTnet.Extensions.ManagedClient
}

_messageQueue.Enqueue(applicationMessage);
if (_storageManager != null)
{
if (removedMessage != null)
@@ -206,9 +215,10 @@ namespace MQTTnet.Extensions.ManagedClient
foreach (var topicFilter in topicFilters)
{
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
_subscriptionsNotPushed = true;
_unsubscriptions.Remove(topicFilter.Topic);
}
}
_subscriptionsQueuedSignal.Release();

return Task.FromResult(0);
}
@@ -223,13 +233,11 @@ namespace MQTTnet.Extensions.ManagedClient
{
foreach (var topic in topics)
{
if (_subscriptions.Remove(topic))
{
_unsubscriptions.Add(topic);
_subscriptionsNotPushed = true;
}
_subscriptions.Remove(topic);
_unsubscriptions.Add(topic);
}
}
_subscriptionsQueuedSignal.Release();

return Task.FromResult(0);
}
@@ -255,6 +263,7 @@ namespace MQTTnet.Extensions.ManagedClient
_messageQueue.Dispose();
_messageQueueLock.Dispose();
_mqttClient.Dispose();
_subscriptionsQueuedSignal.Dispose();
}

private void ThrowIfDisposed()
@@ -296,6 +305,12 @@ namespace MQTTnet.Extensions.ManagedClient

_logger.Info("Stopped");
}
_reconnectSubscriptions.Clear();
lock (_subscriptions)
{
_subscriptions.Clear();
_unsubscriptions.Clear();
}
}
}

@@ -311,16 +326,16 @@ namespace MQTTnet.Extensions.ManagedClient
return;
}

if (connectionState == ReconnectionResult.Reconnected || _subscriptionsNotPushed)
if (connectionState == ReconnectionResult.Reconnected)
{
await SynchronizeSubscriptionsAsync().ConfigureAwait(false);
await PublishReconnectSubscriptionsAsync().ConfigureAwait(false);
StartPublishing();
return;
}

if (connectionState == ReconnectionResult.StillConnected)
{
await Task.Delay(Options.ConnectionCheckInterval, cancellationToken).ConfigureAwait(false);
await PublishSubscriptionsAsync(Options.ConnectionCheckInterval, cancellationToken).ConfigureAwait(false);
}
}
catch (OperationCanceledException)
@@ -390,7 +405,7 @@ namespace MQTTnet.Extensions.ManagedClient
// it from the queue. If not, that means this.PublishAsync has already
// removed it, in which case we don't want to do anything.
_messageQueue.RemoveFirst(i => i.Id.Equals(message.Id));
if (_storageManager != null)
{
await _storageManager.RemoveAsync(message).ConfigureAwait(false);
@@ -415,7 +430,7 @@ namespace MQTTnet.Extensions.ManagedClient
using (await _messageQueueLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) //lock to avoid conflict with this.PublishAsync
{
_messageQueue.RemoveFirst(i => i.Id.Equals(message.Id));
if (_storageManager != null)
{
await _storageManager.RemoveAsync(message).ConfigureAwait(false);
@@ -439,50 +454,84 @@ namespace MQTTnet.Extensions.ManagedClient
}
}

private async Task SynchronizeSubscriptionsAsync()
private async Task PublishSubscriptionsAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
_logger.Info("Synchronizing subscriptions");
var endTime = DateTime.UtcNow + timeout;
while (await _subscriptionsQueuedSignal.WaitAsync(GetRemainingTime(endTime), cancellationToken).ConfigureAwait(false))
{
List<TopicFilter> subscriptions;
HashSet<string> unsubscriptions;

List<TopicFilter> subscriptions;
HashSet<string> unsubscriptions;
lock (_subscriptions)
{
subscriptions = _subscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }).ToList();
_subscriptions.Clear();
unsubscriptions = new HashSet<string>(_unsubscriptions);
_unsubscriptions.Clear();
}

lock (_subscriptions)
{
subscriptions = _subscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }).ToList();
if (!subscriptions.Any() && !unsubscriptions.Any())
{
continue;
}

unsubscriptions = new HashSet<string>(_unsubscriptions);
_unsubscriptions.Clear();
_logger.Info("Publishing subscriptions");

_subscriptionsNotPushed = false;
}
foreach (var unsubscription in unsubscriptions)
{
_reconnectSubscriptions.Remove(unsubscription);
}

if (!subscriptions.Any() && !unsubscriptions.Any())
{
return;
}
foreach (var subscription in subscriptions)
{
_reconnectSubscriptions[subscription.Topic] = subscription.QualityOfServiceLevel;
}

try
{
if (unsubscriptions.Any())
try
{
await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false);
if (unsubscriptions.Any())
{
await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false);
}

if (subscriptions.Any())
{
await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false);
}
}
catch (Exception exception)
{
await HandleSubscriptionExceptionAsync(exception).ConfigureAwait(false);
}
}
}

if (subscriptions.Any())
private async Task PublishReconnectSubscriptionsAsync()
{
_logger.Info("Publishing subscriptions at reconnect");

try
{
if (_reconnectSubscriptions.Any())
{
var subscriptions = _reconnectSubscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value });
await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false);
}
}
catch (Exception exception)
{
_logger.Warning(exception, "Synchronizing subscriptions failed.");
_subscriptionsNotPushed = true;
await HandleSubscriptionExceptionAsync(exception).ConfigureAwait(false);
}
}

var synchronizingSubscriptionsFailedHandler = SynchronizingSubscriptionsFailedHandler;
if (SynchronizingSubscriptionsFailedHandler != null)
{
await synchronizingSubscriptionsFailedHandler.HandleSynchronizingSubscriptionsFailedAsync(new ManagedProcessFailedEventArgs(exception)).ConfigureAwait(false);
}
private async Task HandleSubscriptionExceptionAsync(Exception exception)
{
_logger.Warning(exception, "Synchronizing subscriptions failed.");

var synchronizingSubscriptionsFailedHandler = SynchronizingSubscriptionsFailedHandler;
if (SynchronizingSubscriptionsFailedHandler != null)
{
await synchronizingSubscriptionsFailedHandler.HandleSynchronizingSubscriptionsFailedAsync(new ManagedProcessFailedEventArgs(exception)).ConfigureAwait(false);
}
}

@@ -509,7 +558,7 @@ namespace MQTTnet.Extensions.ManagedClient
return ReconnectionResult.NotConnected;
}
}
private void StartPublishing()
{
if (_publishingCancellationToken != null)
@@ -536,5 +585,11 @@ namespace MQTTnet.Extensions.ManagedClient
_connectionCancellationToken?.Dispose();
_connectionCancellationToken = null;
}

private TimeSpan GetRemainingTime(DateTime endTime)
{
var remainingTime = endTime - DateTime.UtcNow;
return remainingTime < TimeSpan.Zero ? TimeSpan.Zero : remainingTime;
}
}
}

+ 239
- 20
Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs 查看文件

@@ -1,10 +1,13 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client;
using MQTTnet.Client.Connecting;
using MQTTnet.Client.Options;
using MQTTnet.Client.Receiving;
using MQTTnet.Diagnostics;
using MQTTnet.Extensions.ManagedClient;
using MQTTnet.Server;
@@ -95,21 +98,20 @@ namespace MQTTnet.Tests
var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost", testEnvironment.ServerPort);

TaskCompletionSource<bool> connected = new TaskCompletionSource<bool>();
managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => { connected.SetResult(true);});
var connected = GetConnectedTask(managedClient);

await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder()
.WithClientOptions(clientOptions)
.Build());

await connected.Task;
await connected;

await managedClient.StopAsync();

Assert.AreEqual(0, (await server.GetClientStatusAsync()).Count);
}
}
[TestMethod]
public async Task Storage_Queue_Drains()
{
@@ -127,12 +129,7 @@ namespace MQTTnet.Tests
.WithTcpServer("localhost", testEnvironment.ServerPort);
var storage = new ManagedMqttClientTestStorage();

TaskCompletionSource<bool> connected = new TaskCompletionSource<bool>();
managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e =>
{
managedClient.ConnectedHandler = null;
connected.SetResult(true);
});
var connected = GetConnectedTask(managedClient);

await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder()
.WithClientOptions(clientOptions)
@@ -140,7 +137,7 @@ namespace MQTTnet.Tests
.WithAutoReconnectDelay(System.TimeSpan.FromSeconds(5))
.Build());

await connected.Task;
await connected;

await testEnvironment.Server.StopAsync();

@@ -151,17 +148,12 @@ namespace MQTTnet.Tests
//in storage at this point (i.e. no waiting).
Assert.AreEqual(1, storage.GetMessageCount());

connected = new TaskCompletionSource<bool>();
managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e =>
{
managedClient.ConnectedHandler = null;
connected.SetResult(true);
});
connected = GetConnectedTask(managedClient);

await testEnvironment.Server.StartAsync(new MqttServerOptionsBuilder()
.WithDefaultEndpointPort(testEnvironment.ServerPort).Build());

await connected.Task;
await connected;

//Wait 500ms here so the client has time to publish the queued message
await Task.Delay(500);
@@ -171,8 +163,235 @@ namespace MQTTnet.Tests
await managedClient.StopAsync();
}
}

[TestMethod]
public async Task Subscriptions_And_Unsubscriptions_Are_Made_And_Reestablished_At_Reconnect()
{
using (var testEnvironment = new TestEnvironment())
{
var unmanagedClient = testEnvironment.CreateClient();
var managedClient = await CreateManagedClientAsync(testEnvironment, unmanagedClient);

var received = SetupReceivingOfMessages(managedClient, 2);

// Perform some opposing subscriptions and unsubscriptions to verify
// that these conflicting subscriptions are handled correctly
await managedClient.SubscribeAsync("keptSubscribed");
await managedClient.SubscribeAsync("subscribedThenUnsubscribed");

await managedClient.UnsubscribeAsync("subscribedThenUnsubscribed");
await managedClient.UnsubscribeAsync("unsubscribedThenSubscribed");

await managedClient.SubscribeAsync("unsubscribedThenSubscribed");

//wait a bit for the subscriptions to become established before the messages are published
await Task.Delay(500);

var sendingClient = await testEnvironment.ConnectClientAsync();

async Task PublishMessages()
{
await sendingClient.PublishAsync("keptSubscribed", new byte[] { 1 });
await sendingClient.PublishAsync("subscribedThenUnsubscribed", new byte[] { 1 });
await sendingClient.PublishAsync("unsubscribedThenSubscribed", new byte[] { 1 });
}

await PublishMessages();

async Task AssertMessagesReceived()
{
var messages = await received;
Assert.AreEqual("keptSubscribed", messages[0].Topic);
Assert.AreEqual("unsubscribedThenSubscribed", messages[1].Topic);
}

await AssertMessagesReceived();

var connected = GetConnectedTask(managedClient);

await unmanagedClient.DisconnectAsync();

// the managed client has to reconnect by itself
await connected;

// wait a bit so that the managed client can reestablish the subscriptions
await Task.Delay(500);

received = SetupReceivingOfMessages(managedClient, 2);

await PublishMessages();

// and then the same subscriptions need to exist again
await AssertMessagesReceived();
}
}

// This case also serves as a regression test for the previous behavior which re-published
// each and every existing subscriptions with every new subscription that was made
// (causing performance problems and having the visible symptom of retained messages being received again)
[TestMethod]
public async Task Subscriptions_Subscribe_Only_New_Subscriptions()
{
using (var testEnvironment = new TestEnvironment())
{
var managedClient = await CreateManagedClientAsync(testEnvironment);

var sendingClient = await testEnvironment.ConnectClientAsync();

await managedClient.SubscribeAsync("topic");

//wait a bit for the subscription to become established
await Task.Delay(500);

await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true });

var messages = await SetupReceivingOfMessages(managedClient, 1);

Assert.AreEqual(1, messages.Count);
Assert.AreEqual("topic", messages.Single().Topic);

await managedClient.SubscribeAsync("anotherTopic");

await Task.Delay(500);

// The subscription of the other topic must not trigger a re-subscription of the existing topic
// (and thus renewed receiving of the retained message)
Assert.AreEqual(1, messages.Count);
}
}

// This case also serves as a regression test for the previous behavior
// that subscriptions were only published at the ConnectionCheckInterval
[TestMethod]
public async Task Subscriptions_Are_Published_Immediately()
{
using (var testEnvironment = new TestEnvironment())
{
// Use a long connection check interval to verify that the subscriptions
// do not depend on the connection check interval anymore
var connectionCheckInterval = TimeSpan.FromSeconds(10);
var managedClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval);
var sendingClient = await testEnvironment.ConnectClientAsync();
await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true });

await managedClient.SubscribeAsync("topic");

var subscribeTime = DateTime.UtcNow;

var messages = await SetupReceivingOfMessages(managedClient, 1);

var elapsed = DateTime.UtcNow - subscribeTime;
Assert.IsTrue(elapsed < TimeSpan.FromSeconds(1), $"Subscriptions must be activated immediately, this one took {elapsed}");
Assert.AreEqual(messages.Single().Topic, "topic");
}
}

[TestMethod]
public async Task Subscriptions_Are_Cleared_At_Logout()
{
using (var testEnvironment = new TestEnvironment())
{
var managedClient = await CreateManagedClientAsync(testEnvironment);

var sendingClient = await testEnvironment.ConnectClientAsync();
await sendingClient.PublishAsync(new MqttApplicationMessage
{ Topic = "topic", Payload = new byte[] { 1 }, Retain = true });

// Wait a bit for the retained message to be available
await Task.Delay(500);

await managedClient.SubscribeAsync("topic");

await SetupReceivingOfMessages(managedClient, 1);

await managedClient.StopAsync();

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost", testEnvironment.ServerPort);
await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder()
.WithClientOptions(clientOptions)
.WithAutoReconnectDelay(TimeSpan.FromSeconds(1))
.Build());

var messages = new List<MqttApplicationMessage>();
managedClient.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(r =>
{
messages.Add(r.ApplicationMessage);
});

await Task.Delay(500);

// After reconnect and then some delay, the retained message must not be received,
// showing that the subscriptions were cleared
Assert.AreEqual(0, messages.Count);
}
}

private async Task<ManagedMqttClient> CreateManagedClientAsync(
TestEnvironment testEnvironment,
IMqttClient underlyingClient = null,
TimeSpan? connectionCheckInterval = null)
{
await testEnvironment.StartServerAsync();

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost", testEnvironment.ServerPort);

var managedOptions = new ManagedMqttClientOptionsBuilder()
.WithClientOptions(clientOptions)
.Build();

// Use a short connection check interval so that subscription operations are performed quickly
// in order to verify against a previous implementation that performed subscriptions only
// at connection check intervals
managedOptions.ConnectionCheckInterval = connectionCheckInterval ?? TimeSpan.FromSeconds(0.1);

var managedClient =
new ManagedMqttClient(underlyingClient ?? testEnvironment.CreateClient(), new MqttNetLogger().CreateChildLogger());

var connected = GetConnectedTask(managedClient);

await managedClient.StartAsync(managedOptions);

await connected;

return managedClient;
}

/// <summary>
/// Returns a task that will finish when the <paramref name="managedClient"/> has connected
/// </summary>
private Task GetConnectedTask(ManagedMqttClient managedClient)
{
TaskCompletionSource<bool> connected = new TaskCompletionSource<bool>();
managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e =>
{
managedClient.ConnectedHandler = null;
connected.SetResult(true);
});
return connected.Task;
}

/// <summary>
/// Returns a task that will return the messages received on <paramref name="managedClient"/>
/// when <paramref name="expectedNumberOfMessages"/> have been received
/// </summary>
private Task<List<MqttApplicationMessage>> SetupReceivingOfMessages(ManagedMqttClient managedClient, int expectedNumberOfMessages)
{
var receivedMessages = new List<MqttApplicationMessage>();
var allReceived = new TaskCompletionSource<List<MqttApplicationMessage>>();
managedClient.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(r =>
{
receivedMessages.Add(r.ApplicationMessage);
if (receivedMessages.Count == expectedNumberOfMessages)
{
allReceived.SetResult(receivedMessages);
}
});
return allReceived.Task;
}
}
public class ManagedMqttClientTestStorage : IManagedMqttClientStorage
{
private IList<ManagedMqttApplicationMessage> _messages = null;


正在加载...
取消
保存