diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index b10019f..a360e21 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -11,6 +11,8 @@ false MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker) and supports v3.1.0, v3.1.1 and v5.0.0 of the MQTT protocol. +* [Server] Added items dictionary to client session in order to share data across interceptors as along as the session exists. +* [Server] Exposed CONNECT packet properties in Application Message and Subscription interceptor. * [MQTTnet.Server] Added REST API for publishing basic messages. Copyright Christian Kratky 2016-2019 diff --git a/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs index c5f3afd..8d378af 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs @@ -24,12 +24,18 @@ namespace MQTTnet.Server.Mqtt { var pythonContext = new PythonDictionary { + { "client_id", context.ClientId }, + { "retain", context.ApplicationMessage.Retain }, + { "username", context.Username }, + { "password", context.Password }, + { "raw_password", new Bytes(context.RawPassword ?? new byte[0]) }, + { "clean_session", context.CleanSession}, + { "authentication_method", context.AuthenticationMethod}, + { "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, { "accept_publish", context.AcceptPublish }, { "close_connection", context.CloseConnection }, - { "client_id", context.ClientId }, { "topic", context.ApplicationMessage.Topic }, - { "qos", (int)context.ApplicationMessage.QualityOfServiceLevel }, - { "retain", context.ApplicationMessage.Retain } + { "qos", (int)context.ApplicationMessage.QualityOfServiceLevel } }; _pythonScriptHostService.InvokeOptionalFunction("on_intercept_application_message", pythonContext); diff --git a/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs b/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs index 5612601..c33580e 100644 --- a/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs +++ b/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs @@ -1,8 +1,11 @@ -namespace MQTTnet.Server +using System.Collections.Generic; +using MQTTnet.Packets; + +namespace MQTTnet.Server { - public class MqttApplicationMessageInterceptorContext + public class MqttApplicationMessageInterceptorContext : MqttBaseInterceptorContext { - public MqttApplicationMessageInterceptorContext(string clientId, MqttApplicationMessage applicationMessage) + public MqttApplicationMessageInterceptorContext(string clientId, IDictionary sessionItems, MqttConnectPacket connectPacket, MqttApplicationMessage applicationMessage) : base(connectPacket, sessionItems) { ClientId = clientId; ApplicationMessage = applicationMessage; diff --git a/Source/MQTTnet/Server/MqttBaseInterceptorContext.cs b/Source/MQTTnet/Server/MqttBaseInterceptorContext.cs new file mode 100644 index 0000000..6909d5e --- /dev/null +++ b/Source/MQTTnet/Server/MqttBaseInterceptorContext.cs @@ -0,0 +1,54 @@ +using System.Collections.Generic; +using System.Text; +using MQTTnet.Packets; + +namespace MQTTnet.Server +{ + public class MqttBaseInterceptorContext + { + private readonly MqttConnectPacket _connectPacket; + + protected MqttBaseInterceptorContext(MqttConnectPacket connectPacket, IDictionary sessionItems) + { + _connectPacket = connectPacket; + SessionItems = sessionItems; + } + + public string Username => _connectPacket?.Username; + + public byte[] RawPassword => _connectPacket?.Password; + + public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); + + public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; + + public bool? CleanSession => _connectPacket?.CleanSession; + + public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; + + public List UserProperties => _connectPacket?.Properties?.UserProperties; + + public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; + + public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; + + public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; + + public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; + + public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; + + public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; + + public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; + + public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; + + public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index ed378bb..a34f28b 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -31,7 +31,6 @@ namespace MQTTnet.Server private readonly IMqttChannelAdapter _channelAdapter; private readonly IMqttDataConverter _dataConverter; private readonly string _endpoint; - private readonly MqttConnectPacket _connectPacket; private readonly DateTime _connectedTimestamp; private Task _packageReceiverTask; @@ -60,22 +59,24 @@ namespace MQTTnet.Server _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; _endpoint = _channelAdapter.Endpoint; - _connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); + ConnectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateChildLogger(nameof(MqttClientConnection)); - _keepAliveMonitor = new MqttClientKeepAliveMonitor(_connectPacket.ClientId, StopAsync, _logger); + _keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, StopAsync, _logger); _connectedTimestamp = DateTime.UtcNow; _lastPacketReceivedTimestamp = _connectedTimestamp; _lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp; } - public string ClientId => _connectPacket.ClientId; + public MqttConnectPacket ConnectPacket { get; } - public MqttClientSession Session { get; } + public string ClientId => ConnectPacket.ClientId; + public MqttClientSession Session { get; } + public async Task StopAsync() { StopInternal(); @@ -133,12 +134,12 @@ namespace MQTTnet.Server _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; _channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted; - Session.WillMessage = _connectPacket.WillMessage; + Session.WillMessage = ConnectPacket.WillMessage; Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger); // TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor. - _keepAliveMonitor.Start(_connectPacket.KeepAlivePeriod, _cancellationToken.Token); + _keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token); await SendAsync( new MqttConnAckPacket @@ -271,7 +272,7 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) { // TODO: Let the channel adapter create the packet. - var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); + var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false); diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 804a223..73263cb 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -12,11 +12,12 @@ namespace MQTTnet.Server private readonly DateTime _createdTimestamp = DateTime.UtcNow; - public MqttClientSession(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger) + public MqttClientSession(string clientId, IDictionary items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger) { ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + Items = items ?? throw new ArgumentNullException(nameof(items)); - SubscriptionsManager = new MqttClientSubscriptionsManager(clientId, eventDispatcher, serverOptions); + SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions); ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); if (logger == null) throw new ArgumentNullException(nameof(logger)); @@ -33,6 +34,11 @@ namespace MQTTnet.Server public MqttClientSessionApplicationMessagesQueue ApplicationMessagesQueue { get; } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary Items { get; } + public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage) { var checkSubscriptionsResult = SubscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel); @@ -48,7 +54,7 @@ namespace MQTTnet.Server public async Task SubscribeAsync(ICollection topicFilters, MqttRetainedMessagesManager retainedMessagesManager) { - await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); + await SubscriptionsManager.SubscribeAsync(topicFilters, null).ConfigureAwait(false); var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); foreach (var matchingRetainedMessage in matchingRetainedMessages) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index c2d5637..1463ed4 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -20,7 +20,8 @@ namespace MQTTnet.Server private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1); private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); - + private readonly IDictionary _serverSessionItems = new ConcurrentDictionary(); + private readonly CancellationToken _cancellationToken; private readonly MqttServerEventDispatcher _eventDispatcher; @@ -241,19 +242,19 @@ namespace MQTTnet.Server clientId = connectPacket.ClientId; - var validatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); + var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); - if (validatorContext.ReasonCode != MqttConnectReasonCode.Success) + if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { // Send failure response here without preparing a session. The result for a successful connect // will be sent from the session itself. - var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(validatorContext); + var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); await channelAdapter.SendPacketAsync(connAckPacket, _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); return; } - var connection = await CreateConnectionAsync(channelAdapter, connectPacket).ConfigureAwait(false); + var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); @@ -302,8 +303,7 @@ namespace MQTTnet.Server await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false); // Check the client ID and set a random one if supported. - if (string.IsNullOrEmpty(connectPacket.ClientId) && - channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) + if (string.IsNullOrEmpty(connectPacket.ClientId) && channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) { connectPacket.ClientId = context.AssignedClientIdentifier; } @@ -316,7 +316,7 @@ namespace MQTTnet.Server return context; } - private async Task CreateConnectionAsync(IMqttChannelAdapter channelAdapter, MqttConnectPacket connectPacket) + private async Task CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) { await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); try @@ -345,7 +345,7 @@ namespace MQTTnet.Server if (session == null) { - session = new MqttClientSession(connectPacket.ClientId, _eventDispatcher, _options, _logger); + session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _logger); _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } @@ -362,7 +362,7 @@ namespace MQTTnet.Server } } - private async Task InterceptApplicationMessageAsync(MqttClientConnection sender, MqttApplicationMessage applicationMessage) + private async Task InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage) { var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) @@ -370,13 +370,25 @@ namespace MQTTnet.Server return null; } - var senderClientId = sender?.ClientId; - if (sender == null) + string senderClientId; + IDictionary sessionItems; + MqttConnectPacket connectPacket; + + var messageIsFromServer = senderConnection == null; + if (messageIsFromServer) { senderClientId = _options.ClientId; + sessionItems = _serverSessionItems; + connectPacket = null; + } + else + { + senderClientId = senderConnection.ClientId; + sessionItems = senderConnection.Session.Items; + connectPacket = senderConnection.ConnectPacket; } - var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, applicationMessage); + var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, connectPacket, applicationMessage); await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); return interceptorContext; } diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index e2024a6..04d7495 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -10,20 +10,23 @@ namespace MQTTnet.Server public class MqttClientSubscriptionsManager { private readonly Dictionary _subscriptions = new Dictionary(); - private readonly IMqttServerOptions _options; + private readonly MqttClientSession _clientSession; + private readonly IMqttServerOptions _serverOptions; private readonly MqttServerEventDispatcher _eventDispatcher; - private readonly string _clientId; - public MqttClientSubscriptionsManager(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions options) + public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions) { - _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); - _options = options ?? throw new ArgumentNullException(nameof(options)); + _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); + + // TODO: Consider removing the server options here and build a new class "ISubscriptionInterceptor" and just pass it. The instance is generated in the root server class upon start. + _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); } - public async Task SubscribeAsync(MqttSubscribePacket subscribePacket) + public async Task SubscribeAsync(MqttSubscribePacket subscribePacket, MqttConnectPacket connectPacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); + if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); var result = new MqttClientSubscribeResult { @@ -37,7 +40,7 @@ namespace MQTTnet.Server foreach (var originalTopicFilter in subscribePacket.TopicFilters) { - var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false); + var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter, connectPacket).ConfigureAwait(false); var finalTopicFilter = interceptorContext.TopicFilter; @@ -64,18 +67,20 @@ namespace MQTTnet.Server _subscriptions[finalTopicFilter.Topic] = finalTopicFilter.QualityOfServiceLevel; } - await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, finalTopicFilter).ConfigureAwait(false); + await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); } } return result; } - public async Task SubscribeAsync(IEnumerable topicFilters) + public async Task SubscribeAsync(IEnumerable topicFilters, MqttConnectPacket connectPacket) { + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + foreach (var topicFilter in topicFilters) { - var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); + var interceptorContext = await InterceptSubscribeAsync(topicFilter, connectPacket).ConfigureAwait(false); if (!interceptorContext.AcceptSubscription) { continue; @@ -88,7 +93,7 @@ namespace MQTTnet.Server _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; } - await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false); + await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } } } @@ -119,7 +124,7 @@ namespace MQTTnet.Server foreach (var topicFilter in unsubscribePacket.TopicFilters) { - await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false); + await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } return unsubAckPacket; @@ -190,12 +195,12 @@ namespace MQTTnet.Server } } - private async Task InterceptSubscribeAsync(TopicFilter topicFilter) + private async Task InterceptSubscribeAsync(TopicFilter topicFilter, MqttConnectPacket connectPacket) { - var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter); - if (_options.SubscriptionInterceptor != null) + var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, connectPacket, _clientSession.Items); + if (_serverOptions.SubscriptionInterceptor != null) { - await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); + await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); } return context; diff --git a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs index 45dba13..2ab3383 100644 --- a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs +++ b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs @@ -1,7 +1,7 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Security.Cryptography.X509Certificates; -using System.Text; using MQTTnet.Adapter; using MQTTnet.Formatter; using MQTTnet.Packets; @@ -9,51 +9,19 @@ using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttConnectionValidatorContext + public class MqttConnectionValidatorContext : MqttBaseInterceptorContext { private readonly MqttConnectPacket _connectPacket; private readonly IMqttChannelAdapter _clientAdapter; - public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) + public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) : base(connectPacket, new ConcurrentDictionary()) { - _connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); + _connectPacket = connectPacket; _clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); } public string ClientId => _connectPacket.ClientId; - public string Username => _connectPacket.Username; - - public byte[] RawPassword => _connectPacket.Password; - - public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); - - public MqttApplicationMessage WillMessage => _connectPacket.WillMessage; - - public bool CleanSession => _connectPacket.CleanSession; - - public ushort KeepAlivePeriod => _connectPacket.KeepAlivePeriod; - - public List UserProperties => _connectPacket.Properties?.UserProperties; - - public byte[] AuthenticationData => _connectPacket.Properties?.AuthenticationData; - - public string AuthenticationMethod => _connectPacket.Properties?.AuthenticationMethod; - - public uint? MaximumPacketSize => _connectPacket.Properties?.MaximumPacketSize; - - public ushort? ReceiveMaximum => _connectPacket.Properties?.ReceiveMaximum; - - public ushort? TopicAliasMaximum => _connectPacket.Properties?.TopicAliasMaximum; - - public bool? RequestProblemInformation => _connectPacket.Properties?.RequestProblemInformation; - - public bool? RequestResponseInformation => _connectPacket.Properties?.RequestResponseInformation; - - public uint? SessionExpiryInterval => _connectPacket.Properties?.SessionExpiryInterval; - - public uint? WillDelayInterval => _connectPacket.Properties?.WillDelayInterval; - public string Endpoint => _clientAdapter.Endpoint; public bool IsSecureConnection => _clientAdapter.IsSecureConnection; @@ -61,7 +29,7 @@ namespace MQTTnet.Server public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; - + /// /// This is used for MQTTv3 only. /// diff --git a/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs b/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs index ca98c95..74799b2 100644 --- a/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs +++ b/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs @@ -1,10 +1,12 @@ using System; +using System.Collections.Generic; +using MQTTnet.Packets; namespace MQTTnet.Server { - public class MqttSubscriptionInterceptorContext + public class MqttSubscriptionInterceptorContext : MqttBaseInterceptorContext { - public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter) + public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, MqttConnectPacket connectPacket, IDictionary sessionItems) : base(connectPacket, sessionItems) { ClientId = clientId; TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter)); @@ -13,7 +15,7 @@ namespace MQTTnet.Server public string ClientId { get; } public TopicFilter TopicFilter { get; set; } - + public bool AcceptSubscription { get; set; } = true; public bool CloseConnection { get; set; } diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs index 1c4ec84..6f0d542 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs @@ -1,4 +1,6 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Concurrent; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Server; @@ -10,14 +12,17 @@ namespace MQTTnet.Tests public class MqttSubscriptionsManager_Tests { [TestMethod] - public void MqttSubscriptionsManager_SubscribeSingleSuccess() + public async Task MqttSubscriptionsManager_SubscribeSingleSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); Assert.IsTrue(result.IsSubscribed); @@ -25,14 +30,17 @@ namespace MQTTnet.Tests } [TestMethod] - public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() + public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); @@ -40,15 +48,18 @@ namespace MQTTnet.Tests } [TestMethod] - public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() + public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); @@ -56,33 +67,39 @@ namespace MQTTnet.Tests } [TestMethod] - public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() + public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } [TestMethod] - public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() + public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); var up = new MqttUnsubscribePacket(); up.TopicFilters.Add("A/B/C"); - sm.UnsubscribeAsync(up); + await sm.UnsubscribeAsync(up); Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); }