From ba16ae6568f64a56495e04c7798fe24b687a8303 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sat, 28 Oct 2017 10:34:25 +0200 Subject: [PATCH] Add interceptor for client subscriptions --- MQTTnet.Core/Client/MqttClient.cs | 59 ++++++++++--------- ...qttApplicationMessageInterceptorContext.cs | 7 +++ MQTTnet.Core/Server/MqttClientSession.cs | 42 ++++++++----- .../Server/MqttClientSubscribeResult.cs | 11 ++++ .../Server/MqttClientSubscriptionsManager.cs | 30 +++++++--- MQTTnet.Core/Server/MqttServer.cs | 10 +++- MQTTnet.Core/Server/MqttServerOptions.cs | 4 +- .../MqttSubscriptionInterceptorContext.cs | 19 ++++++ Tests/MQTTnet.Core.Tests/MqttServerTests.cs | 5 +- .../MqttSubscriptionsManagerTests.cs | 9 +-- Tests/MQTTnet.TestApp.NetCore/ServerTest.cs | 10 ++-- 11 files changed, 140 insertions(+), 66 deletions(-) create mode 100644 MQTTnet.Core/Server/MqttApplicationMessageInterceptorContext.cs create mode 100644 MQTTnet.Core/Server/MqttClientSubscribeResult.cs create mode 100644 MQTTnet.Core/Server/MqttSubscriptionInterceptorContext.cs diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index b54288c..914986c 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -149,36 +149,36 @@ namespace MQTTnet.Core.Client switch (qosGroup.Key) { case MqttQualityOfServiceLevel.AtMostOnce: - { - // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] - await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, qosPackets); - break; - } + { + // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] + await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, qosPackets); + break; + } case MqttQualityOfServiceLevel.AtLeastOnce: + { + foreach (var publishPacket in qosPackets) { - foreach (var publishPacket in qosPackets) - { - publishPacket.PacketIdentifier = GetNewPacketIdentifier(); - await SendAndReceiveAsync(publishPacket); - } - - break; + publishPacket.PacketIdentifier = GetNewPacketIdentifier(); + await SendAndReceiveAsync(publishPacket); } + + break; + } case MqttQualityOfServiceLevel.ExactlyOnce: + { + foreach (var publishPacket in qosPackets) { - foreach (var publishPacket in qosPackets) - { - publishPacket.PacketIdentifier = GetNewPacketIdentifier(); - var pubRecPacket = await SendAndReceiveAsync(publishPacket).ConfigureAwait(false); - await SendAndReceiveAsync(pubRecPacket.CreateResponse()).ConfigureAwait(false); - } - - break; + publishPacket.PacketIdentifier = GetNewPacketIdentifier(); + var pubRecPacket = await SendAndReceiveAsync(publishPacket).ConfigureAwait(false); + await SendAndReceiveAsync(pubRecPacket.CreateResponse()).ConfigureAwait(false); } + + break; + } default: - { - throw new InvalidOperationException(); - } + { + throw new InvalidOperationException(); + } } } } @@ -191,7 +191,7 @@ namespace MQTTnet.Core.Client Username = _options.Credentials?.Username, Password = _options.Credentials?.Password, CleanSession = _options.CleanSession, - KeepAlivePeriod = (ushort)_options.KeepAlivePeriod.TotalSeconds, + KeepAlivePeriod = (ushort) _options.KeepAlivePeriod.TotalSeconds, WillMessage = willApplicationMessage }; @@ -324,7 +324,7 @@ namespace MQTTnet.Core.Client if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) { FireApplicationMessageReceivedEvent(publishPacket); - await SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + await SendAsync(new MqttPubAckPacket {PacketIdentifier = publishPacket.PacketIdentifier}); return; } @@ -337,7 +337,7 @@ namespace MQTTnet.Core.Client } FireApplicationMessageReceivedEvent(publishPacket); - await SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + await SendAsync(new MqttPubRecPacket {PacketIdentifier = publishPacket.PacketIdentifier}); return; } @@ -363,12 +363,12 @@ namespace MQTTnet.Core.Client { var packetAwaiter = _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.CommunicationTimeout); await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false); - return (TResponsePacket)await packetAwaiter.ConfigureAwait(false); + return (TResponsePacket) await packetAwaiter.ConfigureAwait(false); } private ushort GetNewPacketIdentifier() { - return (ushort)Interlocked.Increment(ref _latestPacketIdentifier); + return (ushort) Interlocked.Increment(ref _latestPacketIdentifier); } private async Task SendKeepAliveMessagesAsync(CancellationToken cancellationToken) @@ -465,7 +465,8 @@ namespace MQTTnet.Core.Client private void StartSendKeepAliveMessages(CancellationToken cancellationToken) { #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Factory.StartNew(() => SendKeepAliveMessagesAsync(cancellationToken), cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default).ConfigureAwait(false); + Task.Factory.StartNew(() => SendKeepAliveMessagesAsync(cancellationToken), cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default) + .ConfigureAwait(false); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } } diff --git a/MQTTnet.Core/Server/MqttApplicationMessageInterceptorContext.cs b/MQTTnet.Core/Server/MqttApplicationMessageInterceptorContext.cs new file mode 100644 index 0000000..218a02d --- /dev/null +++ b/MQTTnet.Core/Server/MqttApplicationMessageInterceptorContext.cs @@ -0,0 +1,7 @@ +namespace MQTTnet.Core.Server +{ + public class MqttApplicationMessageInterceptorContext + { + public MqttApplicationMessage ApplicationMessage { get; set; } + } +} diff --git a/MQTTnet.Core/Server/MqttClientSession.cs b/MQTTnet.Core/Server/MqttClientSession.cs index 74d93f4..e1f8457 100644 --- a/MQTTnet.Core/Server/MqttClientSession.cs +++ b/MQTTnet.Core/Server/MqttClientSession.cs @@ -124,7 +124,7 @@ namespace MQTTnet.Core.Server while (!cancellationToken.IsCancellationRequested) { var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); - await ProcessReceivedPacketAsync(adapter, packet).ConfigureAwait(false); + await ProcessReceivedPacketAsync(adapter, packet, cancellationToken).ConfigureAwait(false); } } catch (OperationCanceledException) @@ -142,28 +142,35 @@ namespace MQTTnet.Core.Server } } - private async Task ProcessReceivedPacketAsync(IMqttCommunicationAdapter adapter, MqttBasePacket packet) + private async Task ProcessReceivedPacketAsync(IMqttCommunicationAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) { if (packet is MqttSubscribePacket subscribePacket) { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Subscribe(subscribePacket)); + var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket); EnqueueRetainedMessages(subscribePacket); + + if (subscribeResult.CloseConnection) + { + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttDisconnectPacket()); + Stop(); + } } else if (packet is MqttUnsubscribePacket unsubscribePacket) { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Unsubscribe(unsubscribePacket)); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, _subscriptionsManager.Unsubscribe(unsubscribePacket)); } else if (packet is MqttPublishPacket publishPacket) { - await HandleIncomingPublishPacketAsync(adapter, publishPacket); + await HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken); } else if (packet is MqttPubRelPacket pubRelPacket) { - await HandleIncomingPubRelPacketAsync(adapter, pubRelPacket); + await HandleIncomingPubRelPacketAsync(adapter, pubRelPacket, cancellationToken); } else if (packet is MqttPubRecPacket pubRecPacket) { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, pubRecPacket.CreateResponse()); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, pubRecPacket.CreateResponse()); } else if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) { @@ -171,7 +178,7 @@ namespace MQTTnet.Core.Server } else if (packet is MqttPingReqPacket) { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPingRespPacket()); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPingRespPacket()); } else if (packet is MqttDisconnectPacket || packet is MqttConnectPacket) { @@ -193,10 +200,17 @@ namespace MQTTnet.Core.Server } } - private async Task HandleIncomingPublishPacketAsync(IMqttCommunicationAdapter adapter, MqttPublishPacket publishPacket) + private async Task HandleIncomingPublishPacketAsync(IMqttCommunicationAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { var applicationMessage = publishPacket.ToApplicationMessage(); - _options.ApplicationMessageInterceptor?.Invoke(applicationMessage); + + var interceptorContext = new MqttApplicationMessageInterceptorContext + { + ApplicationMessage = applicationMessage + }; + + _options.ApplicationMessageInterceptor?.Invoke(interceptorContext); + applicationMessage = interceptorContext.ApplicationMessage; if (applicationMessage.Retain) { @@ -214,7 +228,7 @@ namespace MQTTnet.Core.Server { _sessionsManager.DispatchApplicationMessage(this, applicationMessage); - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); return; @@ -229,7 +243,7 @@ namespace MQTTnet.Core.Server _sessionsManager.DispatchApplicationMessage(this, applicationMessage); - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); return; @@ -239,14 +253,14 @@ namespace MQTTnet.Core.Server } } - private Task HandleIncomingPubRelPacketAsync(IMqttCommunicationAdapter adapter, MqttPubRelPacket pubRelPacket) + private Task HandleIncomingPubRelPacketAsync(IMqttCommunicationAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { lock (_unacknowledgedPublishPackets) { _unacknowledgedPublishPackets.Remove(pubRelPacket.PacketIdentifier); } - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }); + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }); } } } diff --git a/MQTTnet.Core/Server/MqttClientSubscribeResult.cs b/MQTTnet.Core/Server/MqttClientSubscribeResult.cs new file mode 100644 index 0000000..17bfe9a --- /dev/null +++ b/MQTTnet.Core/Server/MqttClientSubscribeResult.cs @@ -0,0 +1,11 @@ +using MQTTnet.Core.Packets; + +namespace MQTTnet.Core.Server +{ + public class MqttClientSubscribeResult + { + public MqttSubAckPacket ResponsePacket { get; set; } + + public bool CloseConnection { get; set; } + } +} diff --git a/MQTTnet.Core/Server/MqttClientSubscriptionsManager.cs b/MQTTnet.Core/Server/MqttClientSubscriptionsManager.cs index 14d63d9..7808276 100644 --- a/MQTTnet.Core/Server/MqttClientSubscriptionsManager.cs +++ b/MQTTnet.Core/Server/MqttClientSubscriptionsManager.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.Options; using MQTTnet.Core.Packets; using MQTTnet.Core.Protocol; @@ -8,30 +9,45 @@ namespace MQTTnet.Core.Server public sealed class MqttClientSubscriptionsManager { private readonly Dictionary _subscribedTopics = new Dictionary(); + private readonly MqttServerOptions _options; - public MqttClientSubscriptionsManager() + public MqttClientSubscriptionsManager(IOptions options) { - + _options = options?.Value ?? throw new ArgumentNullException(nameof(options)); } - public MqttSubAckPacket Subscribe(MqttSubscribePacket subscribePacket) + public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); var responsePacket = subscribePacket.CreateResponse(); + var closeConnection = false; lock (_subscribedTopics) { foreach (var topicFilter in subscribePacket.TopicFilters) { + var interceptorContext = new MqttSubscriptionInterceptorContext("", topicFilter); + _options.SubscriptionsInterceptor?.Invoke(interceptorContext); + responsePacket.SubscribeReturnCodes.Add(interceptorContext.AcceptSubscription ? MqttSubscribeReturnCode.SuccessMaximumQoS1 : MqttSubscribeReturnCode.Failure); + + if (interceptorContext.CloseConnection) + { + closeConnection = true; + } - - _subscribedTopics[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; - responsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.SuccessMaximumQoS1); // TODO: Add support for QoS 2. + if (interceptorContext.AcceptSubscription) + { + _subscribedTopics[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + } } } - return responsePacket; + return new MqttClientSubscribeResult + { + ResponsePacket = responsePacket, + CloseConnection = closeConnection + }; } public MqttUnsubAckPacket Unsubscribe(MqttUnsubscribePacket unsubscribePacket) diff --git a/MQTTnet.Core/Server/MqttServer.cs b/MQTTnet.Core/Server/MqttServer.cs index ff67e14..8f3800c 100644 --- a/MQTTnet.Core/Server/MqttServer.cs +++ b/MQTTnet.Core/Server/MqttServer.cs @@ -56,8 +56,14 @@ namespace MQTTnet.Core.Server foreach (var applicationMessage in applicationMessages) { - _options.ApplicationMessageInterceptor?.Invoke(applicationMessage); - _clientSessionsManager.DispatchApplicationMessage(null, applicationMessage); + var interceptorContext = new MqttApplicationMessageInterceptorContext + { + ApplicationMessage = applicationMessage + }; + + _options.ApplicationMessageInterceptor?.Invoke(interceptorContext); + + _clientSessionsManager.DispatchApplicationMessage(null, interceptorContext.ApplicationMessage); } } diff --git a/MQTTnet.Core/Server/MqttServerOptions.cs b/MQTTnet.Core/Server/MqttServerOptions.cs index 25f7cc8..4b09676 100644 --- a/MQTTnet.Core/Server/MqttServerOptions.cs +++ b/MQTTnet.Core/Server/MqttServerOptions.cs @@ -16,7 +16,9 @@ namespace MQTTnet.Core.Server public Func ConnectionValidator { get; set; } - public Func ApplicationMessageInterceptor { get; set; } + public Action ApplicationMessageInterceptor { get; set; } + + public Action SubscriptionsInterceptor { get; set; } public IMqttServerStorage Storage { get; set; } } diff --git a/MQTTnet.Core/Server/MqttSubscriptionInterceptorContext.cs b/MQTTnet.Core/Server/MqttSubscriptionInterceptorContext.cs new file mode 100644 index 0000000..93282d3 --- /dev/null +++ b/MQTTnet.Core/Server/MqttSubscriptionInterceptorContext.cs @@ -0,0 +1,19 @@ +namespace MQTTnet.Core.Server +{ + public class MqttSubscriptionInterceptorContext + { + public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter) + { + ClientId = clientId; + TopicFilter = topicFilter; + } + + public string ClientId { get; } + + public TopicFilter TopicFilter { get; } + + public bool AcceptSubscription { get; set; } = true; + + public bool CloseConnection { get; set; } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index d548ff0..d3573a6 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -245,10 +245,9 @@ namespace MQTTnet.Core.Tests [TestMethod] public async Task MqttServer_InterceptMessage() { - MqttApplicationMessage Interceptor(MqttApplicationMessage message) + void Interceptor(MqttApplicationMessageInterceptorContext context) { - message.Payload = Encoding.ASCII.GetBytes("extended"); - return message; + context.ApplicationMessage.Payload = Encoding.ASCII.GetBytes("extended"); } var serverAdapter = new TestMqttServerAdapter(); diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index e477731..a2d26f9 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -1,4 +1,5 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.Extensions.Options; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Core.Packets; using MQTTnet.Core.Protocol; using MQTTnet.Core.Server; @@ -11,7 +12,7 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleSuccess() { - var sm = new MqttClientSubscriptionsManager(); + var sm = new MqttClientSubscriptionsManager(new OptionsWrapper(new MqttServerOptions())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); @@ -30,7 +31,7 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var sm = new MqttClientSubscriptionsManager(); + var sm = new MqttClientSubscriptionsManager(new OptionsWrapper(new MqttServerOptions())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); @@ -49,7 +50,7 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var sm = new MqttClientSubscriptionsManager(); + var sm = new MqttClientSubscriptionsManager(new OptionsWrapper(new MqttServerOptions())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); diff --git a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs index a0c6462..0a7d557 100644 --- a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs @@ -35,16 +35,14 @@ namespace MQTTnet.TestApp.NetCore options.Storage = new RetainedMessageHandler(); - options.ApplicationMessageInterceptor = message => + options.ApplicationMessageInterceptor = context => { - if (MqttTopicFilterComparer.IsMatch(message.Topic, "/myTopic/WithTimestamp/#")) + if (MqttTopicFilterComparer.IsMatch(context.ApplicationMessage.Topic, "/myTopic/WithTimestamp/#")) { // Replace the payload with the timestamp. But also extending a JSON // based payload with the timestamp is a suitable use case. - message.Payload = Encoding.UTF8.GetBytes(DateTime.Now.ToString("O")); - } - - return message; + context.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(DateTime.Now.ToString("O")); + } }; });