diff --git a/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs index 8d378af..00eb0e7 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs @@ -22,16 +22,13 @@ namespace MQTTnet.Server.Mqtt { try { + var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; + var pythonContext = new PythonDictionary { { "client_id", context.ClientId }, + { "session_items", sessionItems }, { "retain", context.ApplicationMessage.Retain }, - { "username", context.Username }, - { "password", context.Password }, - { "raw_password", new Bytes(context.RawPassword ?? new byte[0]) }, - { "clean_session", context.CleanSession}, - { "authentication_method", context.AuthenticationMethod}, - { "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, { "accept_publish", context.AcceptPublish }, { "close_connection", context.CloseConnection }, { "topic", context.ApplicationMessage.Topic }, diff --git a/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs b/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs index 3b1a2fc..d002842 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs @@ -9,6 +9,8 @@ namespace MQTTnet.Server.Mqtt { public class MqttServerConnectionValidator : IMqttServerConnectionValidator { + public const string WrappedSessionItemsKey = "WRAPPED_ITEMS"; + private readonly PythonScriptHostService _pythonScriptHostService; private readonly ILogger _logger; @@ -22,6 +24,8 @@ namespace MQTTnet.Server.Mqtt { try { + var sessionItems = new PythonDictionary(); + var pythonContext = new PythonDictionary { { "endpoint", context.Endpoint }, @@ -33,6 +37,7 @@ namespace MQTTnet.Server.Mqtt { "clean_session", context.CleanSession}, { "authentication_method", context.AuthenticationMethod}, { "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, + { "session_items", sessionItems }, { "result", PythonConvert.Pythonfy(context.ReasonCode) } }; @@ -40,6 +45,8 @@ namespace MQTTnet.Server.Mqtt _pythonScriptHostService.InvokeOptionalFunction("on_validate_client_connection", pythonContext); context.ReasonCode = PythonConvert.ParseEnum((string)pythonContext["result"]); + + context.SessionItems[WrappedSessionItemsKey] = sessionItems; } catch (Exception exception) { diff --git a/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs index 2d37f74..ba99e9f 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs @@ -21,14 +21,16 @@ namespace MQTTnet.Server.Mqtt { try { + var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; + var pythonContext = new PythonDictionary { - { "accept_subscription", context.AcceptSubscription }, - { "close_connection", context.CloseConnection }, - { "client_id", context.ClientId }, + { "session_items", sessionItems }, { "topic", context.TopicFilter.Topic }, - { "qos", (int)context.TopicFilter.QualityOfServiceLevel } + { "qos", (int)context.TopicFilter.QualityOfServiceLevel }, + { "accept_subscription", context.AcceptSubscription }, + { "close_connection", context.CloseConnection } }; _pythonScriptHostService.InvokeOptionalFunction("on_intercept_subscription", pythonContext); diff --git a/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs b/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs index c33580e..11efa57 100644 --- a/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs +++ b/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs @@ -1,20 +1,25 @@ using System.Collections.Generic; -using MQTTnet.Packets; namespace MQTTnet.Server { - public class MqttApplicationMessageInterceptorContext : MqttBaseInterceptorContext + public class MqttApplicationMessageInterceptorContext { - public MqttApplicationMessageInterceptorContext(string clientId, IDictionary sessionItems, MqttConnectPacket connectPacket, MqttApplicationMessage applicationMessage) : base(connectPacket, sessionItems) + public MqttApplicationMessageInterceptorContext(string clientId, IDictionary sessionItems, MqttApplicationMessage applicationMessage) { ClientId = clientId; ApplicationMessage = applicationMessage; + SessionItems = sessionItems; } public string ClientId { get; } public MqttApplicationMessage ApplicationMessage { get; set; } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + public bool AcceptPublish { get; set; } = true; public bool CloseConnection { get; set; } diff --git a/Source/MQTTnet/Server/MqttBaseInterceptorContext.cs b/Source/MQTTnet/Server/MqttBaseInterceptorContext.cs deleted file mode 100644 index 6909d5e..0000000 --- a/Source/MQTTnet/Server/MqttBaseInterceptorContext.cs +++ /dev/null @@ -1,54 +0,0 @@ -using System.Collections.Generic; -using System.Text; -using MQTTnet.Packets; - -namespace MQTTnet.Server -{ - public class MqttBaseInterceptorContext - { - private readonly MqttConnectPacket _connectPacket; - - protected MqttBaseInterceptorContext(MqttConnectPacket connectPacket, IDictionary sessionItems) - { - _connectPacket = connectPacket; - SessionItems = sessionItems; - } - - public string Username => _connectPacket?.Username; - - public byte[] RawPassword => _connectPacket?.Password; - - public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); - - public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; - - public bool? CleanSession => _connectPacket?.CleanSession; - - public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; - - public List UserProperties => _connectPacket?.Properties?.UserProperties; - - public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; - - public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; - - public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; - - public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; - - public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; - - public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; - - public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; - - public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; - - public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; - - /// - /// Gets or sets a key/value collection that can be used to share data within the scope of this session. - /// - public IDictionary SessionItems { get; } - } -} \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 73263cb..d165001 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -54,7 +54,7 @@ namespace MQTTnet.Server public async Task SubscribeAsync(ICollection topicFilters, MqttRetainedMessagesManager retainedMessagesManager) { - await SubscriptionsManager.SubscribeAsync(topicFilters, null).ConfigureAwait(false); + await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); foreach (var matchingRetainedMessage in matchingRetainedMessages) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 1463ed4..db70e95 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -290,7 +290,7 @@ namespace MQTTnet.Server private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) { - var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter); + var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary()); var connectionValidator = _options.ConnectionValidator; @@ -372,23 +372,20 @@ namespace MQTTnet.Server string senderClientId; IDictionary sessionItems; - MqttConnectPacket connectPacket; var messageIsFromServer = senderConnection == null; if (messageIsFromServer) { senderClientId = _options.ClientId; sessionItems = _serverSessionItems; - connectPacket = null; } else { senderClientId = senderConnection.ClientId; sessionItems = senderConnection.Session.Items; - connectPacket = senderConnection.ConnectPacket; } - var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, connectPacket, applicationMessage); + var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, applicationMessage); await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); return interceptorContext; } diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index 04d7495..59eafe5 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -40,7 +40,7 @@ namespace MQTTnet.Server foreach (var originalTopicFilter in subscribePacket.TopicFilters) { - var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter, connectPacket).ConfigureAwait(false); + var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false); var finalTopicFilter = interceptorContext.TopicFilter; @@ -74,13 +74,13 @@ namespace MQTTnet.Server return result; } - public async Task SubscribeAsync(IEnumerable topicFilters, MqttConnectPacket connectPacket) + public async Task SubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); foreach (var topicFilter in topicFilters) { - var interceptorContext = await InterceptSubscribeAsync(topicFilter, connectPacket).ConfigureAwait(false); + var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); if (!interceptorContext.AcceptSubscription) { continue; @@ -195,9 +195,9 @@ namespace MQTTnet.Server } } - private async Task InterceptSubscribeAsync(TopicFilter topicFilter, MqttConnectPacket connectPacket) + private async Task InterceptSubscribeAsync(TopicFilter topicFilter) { - var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, connectPacket, _clientSession.Items); + var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); if (_serverOptions.SubscriptionInterceptor != null) { await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); diff --git a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs index 2ab3383..9a5b8b9 100644 --- a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs +++ b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs @@ -1,7 +1,7 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Security.Cryptography.X509Certificates; +using System.Text; using MQTTnet.Adapter; using MQTTnet.Formatter; using MQTTnet.Packets; @@ -9,15 +9,16 @@ using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttConnectionValidatorContext : MqttBaseInterceptorContext + public class MqttConnectionValidatorContext { private readonly MqttConnectPacket _connectPacket; private readonly IMqttChannelAdapter _clientAdapter; - public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) : base(connectPacket, new ConcurrentDictionary()) + public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter, IDictionary sessionItems) { _connectPacket = connectPacket; _clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); + SessionItems = sessionItems; } public string ClientId => _connectPacket.ClientId; @@ -29,7 +30,44 @@ namespace MQTTnet.Server public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; - + + public string Username => _connectPacket?.Username; + + public byte[] RawPassword => _connectPacket?.Password; + + public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); + + public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; + + public bool? CleanSession => _connectPacket?.CleanSession; + + public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; + + public List UserProperties => _connectPacket?.Properties?.UserProperties; + + public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; + + public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; + + public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; + + public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; + + public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; + + public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; + + public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; + + public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; + + public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + /// /// This is used for MQTTv3 only. /// diff --git a/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs b/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs index 74799b2..7e3963b 100644 --- a/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs +++ b/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs @@ -1,21 +1,25 @@ -using System; -using System.Collections.Generic; -using MQTTnet.Packets; +using System.Collections.Generic; namespace MQTTnet.Server { - public class MqttSubscriptionInterceptorContext : MqttBaseInterceptorContext + public class MqttSubscriptionInterceptorContext { - public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, MqttConnectPacket connectPacket, IDictionary sessionItems) : base(connectPacket, sessionItems) + public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, IDictionary sessionItems) { ClientId = clientId; - TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter)); + TopicFilter = topicFilter; + SessionItems = sessionItems; } public string ClientId { get; } public TopicFilter TopicFilter { get; set; } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + public bool AcceptSubscription { get; set; } = true; public bool CloseConnection { get; set; } diff --git a/Tests/MQTTnet.Core.Tests/Session_Tests.cs b/Tests/MQTTnet.Core.Tests/Session_Tests.cs new file mode 100644 index 0000000..d06bd4e --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Session_Tests.cs @@ -0,0 +1,61 @@ +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Client; +using MQTTnet.Client.Subscribing; +using MQTTnet.Server; +using MQTTnet.Tests.Mockups; + +namespace MQTTnet.Tests +{ + [TestClass] + public class Session_Tests + { + [TestMethod] + public async Task Set_Session_Item() + { + using (var testEnvironment = new TestEnvironment()) + { + var serverOptions = new MqttServerOptionsBuilder() + .WithConnectionValidator(delegate (MqttConnectionValidatorContext context) + { + // Don't validate anything. Just set some session items. + context.SessionItems["can_subscribe_x"] = true; + context.SessionItems["default_payload"] = "Hello World"; + }) + .WithSubscriptionInterceptor(delegate (MqttSubscriptionInterceptorContext context) + { + if (context.TopicFilter.Topic == "x") + { + context.AcceptSubscription = context.SessionItems["can_subscribe_x"] as bool? == true; + } + }) + .WithApplicationMessageInterceptor(delegate (MqttApplicationMessageInterceptorContext context) + { + context.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(context.SessionItems["default_payload"] as string); + }); + + await testEnvironment.StartServerAsync(serverOptions); + + string receivedPayload = null; + + var client = await testEnvironment.ConnectClientAsync(); + client.UseApplicationMessageReceivedHandler(delegate(MqttApplicationMessageReceivedEventArgs args) + { + receivedPayload = args.ApplicationMessage.ConvertPayloadToString(); + }); + + var subscribeResult = await client.SubscribeAsync("x"); + + Assert.AreEqual(MqttClientSubscribeResultCode.GrantedQoS0, subscribeResult.Items[0].ResultCode); + + var client2 = await testEnvironment.ConnectClientAsync(); + await client2.PublishAsync("x"); + + await Task.Delay(1000); + + Assert.AreEqual("Hello World", receivedPayload); + } + } + } +}