diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 70b04d9..f318b85 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -11,12 +11,13 @@ namespace MQTTnet.Server readonly IMqttNetLogger _logger; readonly DateTime _createdTimestamp = DateTime.UtcNow; + readonly IMqttRetainedMessagesManager _retainedMessagesManager; - public MqttClientSession(string clientId, IDictionary items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetLogger logger) + public MqttClientSession(string clientId, IDictionary items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttRetainedMessagesManager retainedMessagesManager, IMqttNetLogger logger) { ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); Items = items ?? throw new ArgumentNullException(nameof(items)); - + _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions); ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); @@ -52,11 +53,13 @@ namespace MQTTnet.Server ApplicationMessagesQueue.Enqueue(applicationMessage, senderClientId, checkSubscriptionsResult.QualityOfServiceLevel, isRetainedApplicationMessage); } - public async Task SubscribeAsync(ICollection topicFilters, IMqttRetainedMessagesManager retainedMessagesManager) + public async Task SubscribeAsync(ICollection topicFilters) { + if (topicFilters is null) throw new ArgumentNullException(nameof(topicFilters)); + await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); - var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); + var matchingRetainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); foreach (var matchingRetainedMessage in matchingRetainedMessages) { EnqueueApplicationMessage(matchingRetainedMessage, null, true); @@ -65,6 +68,8 @@ namespace MQTTnet.Server public Task UnsubscribeAsync(IEnumerable topicFilters) { + if (topicFilters is null) throw new ArgumentNullException(nameof(topicFilters)); + return SubscriptionsManager.UnsubscribeAsync(topicFilters); } diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index f98eb05..b10f8f7 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -118,7 +118,7 @@ namespace MQTTnet.Server throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } - return session.SubscribeAsync(topicFilters, _retainedMessagesManager); + return session.SubscribeAsync(topicFilters); } public Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) @@ -280,7 +280,7 @@ namespace MQTTnet.Server connectPacket, connectionValidatorContext, channelAdapter, - async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), + async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), async disconnectType => await CleanUpClient(clientId, channelAdapter, disconnectType) ).ConfigureAwait(false); @@ -371,7 +371,7 @@ namespace MQTTnet.Server if (session == null) { - session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _logger); + session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _logger); _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs index 6f0d542..3c7724d 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs @@ -1,10 +1,10 @@ -using System.Collections.Concurrent; -using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Server; using MQTTnet.Tests.Mockups; +using System.Collections.Concurrent; +using System.Threading.Tasks; namespace MQTTnet.Tests { @@ -14,8 +14,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttSubscriptionsManager_SubscribeSingleSuccess() { - var s = new MqttClientSession("", new ConcurrentDictionary(), - new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + var s = CreateSession(); var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); @@ -32,8 +31,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() { - var s = new MqttClientSession("", new ConcurrentDictionary(), - new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + var s = CreateSession(); var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); @@ -50,8 +48,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess() { - var s = new MqttClientSession("", new ConcurrentDictionary(), - new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + var s = CreateSession(); var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); @@ -69,8 +66,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var s = new MqttClientSession("", new ConcurrentDictionary(), - new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + var s = CreateSession(); var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); @@ -85,8 +81,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var s = new MqttClientSession("", new ConcurrentDictionary(), - new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + var s = CreateSession(); var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); @@ -103,5 +98,16 @@ namespace MQTTnet.Tests Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } + + MqttClientSession CreateSession() + { + return new MqttClientSession( + "", + new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), + new MqttServerOptions(), + new MqttRetainedMessagesManager(), + new TestLogger()); + } } }