@@ -11,12 +11,13 @@ namespace MQTTnet.Server | |||||
readonly IMqttNetLogger _logger; | readonly IMqttNetLogger _logger; | ||||
readonly DateTime _createdTimestamp = DateTime.UtcNow; | readonly DateTime _createdTimestamp = DateTime.UtcNow; | ||||
readonly IMqttRetainedMessagesManager _retainedMessagesManager; | |||||
public MqttClientSession(string clientId, IDictionary<object, object> items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetLogger logger) | |||||
public MqttClientSession(string clientId, IDictionary<object, object> items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttRetainedMessagesManager retainedMessagesManager, IMqttNetLogger logger) | |||||
{ | { | ||||
ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); | ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); | ||||
Items = items ?? throw new ArgumentNullException(nameof(items)); | Items = items ?? throw new ArgumentNullException(nameof(items)); | ||||
_retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); | |||||
SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions); | SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions); | ||||
ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); | ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); | ||||
@@ -52,11 +53,13 @@ namespace MQTTnet.Server | |||||
ApplicationMessagesQueue.Enqueue(applicationMessage, senderClientId, checkSubscriptionsResult.QualityOfServiceLevel, isRetainedApplicationMessage); | ApplicationMessagesQueue.Enqueue(applicationMessage, senderClientId, checkSubscriptionsResult.QualityOfServiceLevel, isRetainedApplicationMessage); | ||||
} | } | ||||
public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters, IMqttRetainedMessagesManager retainedMessagesManager) | |||||
public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters) | |||||
{ | { | ||||
if (topicFilters is null) throw new ArgumentNullException(nameof(topicFilters)); | |||||
await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); | await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); | ||||
var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); | |||||
var matchingRetainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); | |||||
foreach (var matchingRetainedMessage in matchingRetainedMessages) | foreach (var matchingRetainedMessage in matchingRetainedMessages) | ||||
{ | { | ||||
EnqueueApplicationMessage(matchingRetainedMessage, null, true); | EnqueueApplicationMessage(matchingRetainedMessage, null, true); | ||||
@@ -65,6 +68,8 @@ namespace MQTTnet.Server | |||||
public Task UnsubscribeAsync(IEnumerable<string> topicFilters) | public Task UnsubscribeAsync(IEnumerable<string> topicFilters) | ||||
{ | { | ||||
if (topicFilters is null) throw new ArgumentNullException(nameof(topicFilters)); | |||||
return SubscriptionsManager.UnsubscribeAsync(topicFilters); | return SubscriptionsManager.UnsubscribeAsync(topicFilters); | ||||
} | } | ||||
@@ -118,7 +118,7 @@ namespace MQTTnet.Server | |||||
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); | throw new InvalidOperationException($"Client session '{clientId}' is unknown."); | ||||
} | } | ||||
return session.SubscribeAsync(topicFilters, _retainedMessagesManager); | |||||
return session.SubscribeAsync(topicFilters); | |||||
} | } | ||||
public Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters) | public Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters) | ||||
@@ -280,7 +280,7 @@ namespace MQTTnet.Server | |||||
connectPacket, | connectPacket, | ||||
connectionValidatorContext, | connectionValidatorContext, | ||||
channelAdapter, | channelAdapter, | ||||
async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), | |||||
async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), | |||||
async disconnectType => await CleanUpClient(clientId, channelAdapter, disconnectType) | async disconnectType => await CleanUpClient(clientId, channelAdapter, disconnectType) | ||||
).ConfigureAwait(false); | ).ConfigureAwait(false); | ||||
@@ -371,7 +371,7 @@ namespace MQTTnet.Server | |||||
if (session == null) | if (session == null) | ||||
{ | { | ||||
session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _logger); | |||||
session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _logger); | |||||
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | ||||
} | } | ||||
@@ -1,10 +1,10 @@ | |||||
using System.Collections.Concurrent; | |||||
using System.Threading.Tasks; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using MQTTnet.Packets; | using MQTTnet.Packets; | ||||
using MQTTnet.Protocol; | using MQTTnet.Protocol; | ||||
using MQTTnet.Server; | using MQTTnet.Server; | ||||
using MQTTnet.Tests.Mockups; | using MQTTnet.Tests.Mockups; | ||||
using System.Collections.Concurrent; | |||||
using System.Threading.Tasks; | |||||
namespace MQTTnet.Tests | namespace MQTTnet.Tests | ||||
{ | { | ||||
@@ -14,8 +14,7 @@ namespace MQTTnet.Tests | |||||
[TestMethod] | [TestMethod] | ||||
public async Task MqttSubscriptionsManager_SubscribeSingleSuccess() | public async Task MqttSubscriptionsManager_SubscribeSingleSuccess() | ||||
{ | { | ||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var s = CreateSession(); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | ||||
@@ -32,8 +31,7 @@ namespace MQTTnet.Tests | |||||
[TestMethod] | [TestMethod] | ||||
public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() | public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() | ||||
{ | { | ||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var s = CreateSession(); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | ||||
@@ -50,8 +48,7 @@ namespace MQTTnet.Tests | |||||
[TestMethod] | [TestMethod] | ||||
public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess() | public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess() | ||||
{ | { | ||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var s = CreateSession(); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | ||||
@@ -69,8 +66,7 @@ namespace MQTTnet.Tests | |||||
[TestMethod] | [TestMethod] | ||||
public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess() | public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess() | ||||
{ | { | ||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var s = CreateSession(); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | ||||
@@ -85,8 +81,7 @@ namespace MQTTnet.Tests | |||||
[TestMethod] | [TestMethod] | ||||
public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() | public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() | ||||
{ | { | ||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var s = CreateSession(); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | ||||
@@ -103,5 +98,16 @@ namespace MQTTnet.Tests | |||||
Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | ||||
} | } | ||||
MqttClientSession CreateSession() | |||||
{ | |||||
return new MqttClientSession( | |||||
"", | |||||
new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), | |||||
new MqttServerOptions(), | |||||
new MqttRetainedMessagesManager(), | |||||
new TestLogger()); | |||||
} | |||||
} | } | ||||
} | } |