diff --git a/Source/MQTTnet.Server/Mqtt/MqttServerService.cs b/Source/MQTTnet.Server/Mqtt/MqttServerService.cs index b8c463f..85c4176 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttServerService.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttServerService.cs @@ -33,6 +33,7 @@ namespace MQTTnet.Server.Mqtt private readonly MqttServerConnectionValidator _mqttConnectionValidator; private readonly IMqttServer _mqttServer; private readonly MqttSubscriptionInterceptor _mqttSubscriptionInterceptor; + private readonly MqttUnsubscriptionInterceptor _mqttUnsubscriptionInterceptor; private readonly PythonScriptHostService _pythonScriptHostService; private readonly MqttWebSocketServerAdapter _webSocketServerAdapter; @@ -45,6 +46,7 @@ namespace MQTTnet.Server.Mqtt MqttClientUnsubscribedTopicHandler mqttClientUnsubscribedTopicHandler, MqttServerConnectionValidator mqttConnectionValidator, MqttSubscriptionInterceptor mqttSubscriptionInterceptor, + MqttUnsubscriptionInterceptor mqttUnsubscriptionInterceptor, MqttApplicationMessageInterceptor mqttApplicationMessageInterceptor, MqttServerStorage mqttServerStorage, PythonScriptHostService pythonScriptHostService, @@ -57,6 +59,7 @@ namespace MQTTnet.Server.Mqtt _mqttClientUnsubscribedTopicHandler = mqttClientUnsubscribedTopicHandler ?? throw new ArgumentNullException(nameof(mqttClientUnsubscribedTopicHandler)); _mqttConnectionValidator = mqttConnectionValidator ?? throw new ArgumentNullException(nameof(mqttConnectionValidator)); _mqttSubscriptionInterceptor = mqttSubscriptionInterceptor ?? throw new ArgumentNullException(nameof(mqttSubscriptionInterceptor)); + _mqttUnsubscriptionInterceptor = mqttUnsubscriptionInterceptor ?? throw new ArgumentNullException(nameof(mqttUnsubscriptionInterceptor)); _mqttApplicationMessageInterceptor = mqttApplicationMessageInterceptor ?? throw new ArgumentNullException(nameof(mqttApplicationMessageInterceptor)); _mqttServerStorage = mqttServerStorage ?? throw new ArgumentNullException(nameof(mqttServerStorage)); _pythonScriptHostService = pythonScriptHostService ?? throw new ArgumentNullException(nameof(pythonScriptHostService)); @@ -178,6 +181,7 @@ namespace MQTTnet.Server.Mqtt .WithConnectionValidator(_mqttConnectionValidator) .WithApplicationMessageInterceptor(_mqttApplicationMessageInterceptor) .WithSubscriptionInterceptor(_mqttSubscriptionInterceptor) + .WithUnsubscriptionInterceptor(_mqttUnsubscriptionInterceptor) .WithStorage(_mqttServerStorage); // Configure unencrypted connections diff --git a/Source/MQTTnet.Server/Mqtt/MqttUnsubscriptionInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttUnsubscriptionInterceptor.cs new file mode 100644 index 0000000..1a460af --- /dev/null +++ b/Source/MQTTnet.Server/Mqtt/MqttUnsubscriptionInterceptor.cs @@ -0,0 +1,48 @@ +using System; +using System.Threading.Tasks; +using IronPython.Runtime; +using Microsoft.Extensions.Logging; +using MQTTnet.Server.Scripting; + +namespace MQTTnet.Server.Mqtt +{ + public class MqttUnsubscriptionInterceptor : IMqttServerUnsubscriptionInterceptor + { + private readonly PythonScriptHostService _pythonScriptHostService; + private readonly ILogger _logger; + + public MqttUnsubscriptionInterceptor(PythonScriptHostService pythonScriptHostService, ILogger logger) + { + _pythonScriptHostService = pythonScriptHostService ?? throw new ArgumentNullException(nameof(pythonScriptHostService)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public Task InterceptUnsubscriptionAsync(MqttUnsubscriptionInterceptorContext context) + { + try + { + var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; + + var pythonContext = new PythonDictionary + { + { "client_id", context.ClientId }, + { "session_items", sessionItems }, + { "topic", context.Topic }, + { "accept_unsubscription", context.AcceptUnsubscription }, + { "close_connection", context.CloseConnection } + }; + + _pythonScriptHostService.InvokeOptionalFunction("on_intercept_unsubscription", pythonContext); + + context.AcceptUnsubscription = (bool)pythonContext["accept_unsubscription"]; + context.CloseConnection = (bool)pythonContext["close_connection"]; + } + catch (Exception exception) + { + _logger.LogError(exception, "Error while intercepting unsubscription."); + } + + return Task.CompletedTask; + } + } +} diff --git a/Source/MQTTnet/Server/IMqttServerOptions.cs b/Source/MQTTnet/Server/IMqttServerOptions.cs index 7c5fde4..2145845 100644 --- a/Source/MQTTnet/Server/IMqttServerOptions.cs +++ b/Source/MQTTnet/Server/IMqttServerOptions.cs @@ -15,6 +15,7 @@ namespace MQTTnet.Server IMqttServerConnectionValidator ConnectionValidator { get; } IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; } + IMqttServerUnsubscriptionInterceptor UnsubscriptionInterceptor { get; } IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; } IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; } diff --git a/Source/MQTTnet/Server/IMqttServerUnsubscriptionInterceptor.cs b/Source/MQTTnet/Server/IMqttServerUnsubscriptionInterceptor.cs new file mode 100644 index 0000000..9669383 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerUnsubscriptionInterceptor.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerUnsubscriptionInterceptor + { + Task InterceptUnsubscriptionAsync(MqttUnsubscriptionInterceptorContext context); + } +} diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index c84a018..deeadf4 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -107,9 +107,16 @@ namespace MQTTnet.Server PacketIdentifier = unsubscribePacket.PacketIdentifier }; - lock (_subscriptions) + foreach (var topicFilter in unsubscribePacket.TopicFilters) { - foreach (var topicFilter in unsubscribePacket.TopicFilters) + var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false); + if (!interceptorContext.AcceptUnsubscription) + { + unsubAckPacket.ReasonCodes.Add(MqttUnsubscribeReasonCode.ImplementationSpecificError); + continue; + } + + lock (_subscriptions) { if (_subscriptions.Remove(topicFilter)) { @@ -130,19 +137,23 @@ namespace MQTTnet.Server return unsubAckPacket; } - public Task UnsubscribeAsync(IEnumerable topicFilters) + public async Task UnsubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - lock (_subscriptions) + foreach (var topicFilter in topicFilters) { - foreach (var topicFilter in topicFilters) + var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false); + if (!interceptorContext.AcceptUnsubscription) { - _subscriptions.Remove(topicFilter); + continue; } - } - return Task.FromResult(0); + lock (_subscriptions) + { + _subscriptions.Remove(topicFilter); + } + } } public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel) @@ -206,6 +217,17 @@ namespace MQTTnet.Server return context; } + private async Task InterceptUnsubscribeAsync(string topicFilter) + { + var context = new MqttUnsubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); + if (_serverOptions.UnsubscriptionInterceptor != null) + { + await _serverOptions.UnsubscriptionInterceptor.InterceptUnsubscriptionAsync(context).ConfigureAwait(false); + } + + return context; + } + private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet subscribedQoSLevels) { MqttQualityOfServiceLevel effectiveQoS; diff --git a/Source/MQTTnet/Server/MqttServerOptions.cs b/Source/MQTTnet/Server/MqttServerOptions.cs index 7147ef8..d5f6737 100644 --- a/Source/MQTTnet/Server/MqttServerOptions.cs +++ b/Source/MQTTnet/Server/MqttServerOptions.cs @@ -26,6 +26,8 @@ namespace MQTTnet.Server public IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; set; } + public IMqttServerUnsubscriptionInterceptor UnsubscriptionInterceptor { get; set; } + public IMqttServerStorage Storage { get; set; } } } diff --git a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs index c25af84..15126a1 100644 --- a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs +++ b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs @@ -155,6 +155,12 @@ namespace MQTTnet.Server return this; } + public MqttServerOptionsBuilder WithUnsubscriptionInterceptor(IMqttServerUnsubscriptionInterceptor value) + { + _options.UnsubscriptionInterceptor = value; + return this; + } + public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action value) { _options.SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(value); diff --git a/Source/MQTTnet/Server/MqttUnsubscriptionInterceptorContext.cs b/Source/MQTTnet/Server/MqttUnsubscriptionInterceptorContext.cs new file mode 100644 index 0000000..b33cbac --- /dev/null +++ b/Source/MQTTnet/Server/MqttUnsubscriptionInterceptorContext.cs @@ -0,0 +1,27 @@ +using System.Collections.Generic; + +namespace MQTTnet.Server +{ + public class MqttUnsubscriptionInterceptorContext + { + public MqttUnsubscriptionInterceptorContext(string clientId, string topic, IDictionary sessionItems) + { + ClientId = clientId; + Topic = topic; + SessionItems = sessionItems; + } + + public string ClientId { get; } + + public string Topic { get; set; } + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + + public bool AcceptUnsubscription { get; set; } = true; + + public bool CloseConnection { get; set; } + } +}