@@ -22,16 +22,13 @@ namespace MQTTnet.Server.Mqtt | |||
{ | |||
try | |||
{ | |||
var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; | |||
var pythonContext = new PythonDictionary | |||
{ | |||
{ "client_id", context.ClientId }, | |||
{ "session_items", sessionItems }, | |||
{ "retain", context.ApplicationMessage.Retain }, | |||
{ "username", context.Username }, | |||
{ "password", context.Password }, | |||
{ "raw_password", new Bytes(context.RawPassword ?? new byte[0]) }, | |||
{ "clean_session", context.CleanSession}, | |||
{ "authentication_method", context.AuthenticationMethod}, | |||
{ "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, | |||
{ "accept_publish", context.AcceptPublish }, | |||
{ "close_connection", context.CloseConnection }, | |||
{ "topic", context.ApplicationMessage.Topic }, | |||
@@ -9,6 +9,8 @@ namespace MQTTnet.Server.Mqtt | |||
{ | |||
public class MqttServerConnectionValidator : IMqttServerConnectionValidator | |||
{ | |||
public const string WrappedSessionItemsKey = "WRAPPED_ITEMS"; | |||
private readonly PythonScriptHostService _pythonScriptHostService; | |||
private readonly ILogger _logger; | |||
@@ -22,6 +24,8 @@ namespace MQTTnet.Server.Mqtt | |||
{ | |||
try | |||
{ | |||
var sessionItems = new PythonDictionary(); | |||
var pythonContext = new PythonDictionary | |||
{ | |||
{ "endpoint", context.Endpoint }, | |||
@@ -33,6 +37,7 @@ namespace MQTTnet.Server.Mqtt | |||
{ "clean_session", context.CleanSession}, | |||
{ "authentication_method", context.AuthenticationMethod}, | |||
{ "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, | |||
{ "session_items", sessionItems }, | |||
{ "result", PythonConvert.Pythonfy(context.ReasonCode) } | |||
}; | |||
@@ -40,6 +45,8 @@ namespace MQTTnet.Server.Mqtt | |||
_pythonScriptHostService.InvokeOptionalFunction("on_validate_client_connection", pythonContext); | |||
context.ReasonCode = PythonConvert.ParseEnum<MqttConnectReasonCode>((string)pythonContext["result"]); | |||
context.SessionItems[WrappedSessionItemsKey] = sessionItems; | |||
} | |||
catch (Exception exception) | |||
{ | |||
@@ -21,14 +21,16 @@ namespace MQTTnet.Server.Mqtt | |||
{ | |||
try | |||
{ | |||
var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; | |||
var pythonContext = new PythonDictionary | |||
{ | |||
{ "accept_subscription", context.AcceptSubscription }, | |||
{ "close_connection", context.CloseConnection }, | |||
{ "client_id", context.ClientId }, | |||
{ "session_items", sessionItems }, | |||
{ "topic", context.TopicFilter.Topic }, | |||
{ "qos", (int)context.TopicFilter.QualityOfServiceLevel } | |||
{ "qos", (int)context.TopicFilter.QualityOfServiceLevel }, | |||
{ "accept_subscription", context.AcceptSubscription }, | |||
{ "close_connection", context.CloseConnection } | |||
}; | |||
_pythonScriptHostService.InvokeOptionalFunction("on_intercept_subscription", pythonContext); | |||
@@ -1,20 +1,25 @@ | |||
using System.Collections.Generic; | |||
using MQTTnet.Packets; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttApplicationMessageInterceptorContext : MqttBaseInterceptorContext | |||
public class MqttApplicationMessageInterceptorContext | |||
{ | |||
public MqttApplicationMessageInterceptorContext(string clientId, IDictionary<object, object> sessionItems, MqttConnectPacket connectPacket, MqttApplicationMessage applicationMessage) : base(connectPacket, sessionItems) | |||
public MqttApplicationMessageInterceptorContext(string clientId, IDictionary<object, object> sessionItems, MqttApplicationMessage applicationMessage) | |||
{ | |||
ClientId = clientId; | |||
ApplicationMessage = applicationMessage; | |||
SessionItems = sessionItems; | |||
} | |||
public string ClientId { get; } | |||
public MqttApplicationMessage ApplicationMessage { get; set; } | |||
/// <summary> | |||
/// Gets or sets a key/value collection that can be used to share data within the scope of this session. | |||
/// </summary> | |||
public IDictionary<object, object> SessionItems { get; } | |||
public bool AcceptPublish { get; set; } = true; | |||
public bool CloseConnection { get; set; } | |||
@@ -1,54 +0,0 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using MQTTnet.Packets; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttBaseInterceptorContext | |||
{ | |||
private readonly MqttConnectPacket _connectPacket; | |||
protected MqttBaseInterceptorContext(MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems) | |||
{ | |||
_connectPacket = connectPacket; | |||
SessionItems = sessionItems; | |||
} | |||
public string Username => _connectPacket?.Username; | |||
public byte[] RawPassword => _connectPacket?.Password; | |||
public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); | |||
public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; | |||
public bool? CleanSession => _connectPacket?.CleanSession; | |||
public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; | |||
public List<MqttUserProperty> UserProperties => _connectPacket?.Properties?.UserProperties; | |||
public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; | |||
public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; | |||
public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; | |||
public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; | |||
public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; | |||
public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; | |||
public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; | |||
public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; | |||
public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; | |||
/// <summary> | |||
/// Gets or sets a key/value collection that can be used to share data within the scope of this session. | |||
/// </summary> | |||
public IDictionary<object, object> SessionItems { get; } | |||
} | |||
} |
@@ -54,7 +54,7 @@ namespace MQTTnet.Server | |||
public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters, MqttRetainedMessagesManager retainedMessagesManager) | |||
{ | |||
await SubscriptionsManager.SubscribeAsync(topicFilters, null).ConfigureAwait(false); | |||
await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); | |||
var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); | |||
foreach (var matchingRetainedMessage in matchingRetainedMessages) | |||
@@ -290,7 +290,7 @@ namespace MQTTnet.Server | |||
private async Task<MqttConnectionValidatorContext> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) | |||
{ | |||
var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter); | |||
var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary<object, object>()); | |||
var connectionValidator = _options.ConnectionValidator; | |||
@@ -372,23 +372,20 @@ namespace MQTTnet.Server | |||
string senderClientId; | |||
IDictionary<object, object> sessionItems; | |||
MqttConnectPacket connectPacket; | |||
var messageIsFromServer = senderConnection == null; | |||
if (messageIsFromServer) | |||
{ | |||
senderClientId = _options.ClientId; | |||
sessionItems = _serverSessionItems; | |||
connectPacket = null; | |||
} | |||
else | |||
{ | |||
senderClientId = senderConnection.ClientId; | |||
sessionItems = senderConnection.Session.Items; | |||
connectPacket = senderConnection.ConnectPacket; | |||
} | |||
var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, connectPacket, applicationMessage); | |||
var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, applicationMessage); | |||
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); | |||
return interceptorContext; | |||
} | |||
@@ -40,7 +40,7 @@ namespace MQTTnet.Server | |||
foreach (var originalTopicFilter in subscribePacket.TopicFilters) | |||
{ | |||
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter, connectPacket).ConfigureAwait(false); | |||
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false); | |||
var finalTopicFilter = interceptorContext.TopicFilter; | |||
@@ -74,13 +74,13 @@ namespace MQTTnet.Server | |||
return result; | |||
} | |||
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters, MqttConnectPacket connectPacket) | |||
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters) | |||
{ | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
foreach (var topicFilter in topicFilters) | |||
{ | |||
var interceptorContext = await InterceptSubscribeAsync(topicFilter, connectPacket).ConfigureAwait(false); | |||
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); | |||
if (!interceptorContext.AcceptSubscription) | |||
{ | |||
continue; | |||
@@ -195,9 +195,9 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter, MqttConnectPacket connectPacket) | |||
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter) | |||
{ | |||
var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, connectPacket, _clientSession.Items); | |||
var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); | |||
if (_serverOptions.SubscriptionInterceptor != null) | |||
{ | |||
await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); | |||
@@ -1,7 +1,7 @@ | |||
using System; | |||
using System.Collections.Concurrent; | |||
using System.Collections.Generic; | |||
using System.Security.Cryptography.X509Certificates; | |||
using System.Text; | |||
using MQTTnet.Adapter; | |||
using MQTTnet.Formatter; | |||
using MQTTnet.Packets; | |||
@@ -9,15 +9,16 @@ using MQTTnet.Protocol; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttConnectionValidatorContext : MqttBaseInterceptorContext | |||
public class MqttConnectionValidatorContext | |||
{ | |||
private readonly MqttConnectPacket _connectPacket; | |||
private readonly IMqttChannelAdapter _clientAdapter; | |||
public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) : base(connectPacket, new ConcurrentDictionary<object, object>()) | |||
public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter, IDictionary<object, object> sessionItems) | |||
{ | |||
_connectPacket = connectPacket; | |||
_clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); | |||
SessionItems = sessionItems; | |||
} | |||
public string ClientId => _connectPacket.ClientId; | |||
@@ -29,7 +30,44 @@ namespace MQTTnet.Server | |||
public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; | |||
public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; | |||
public string Username => _connectPacket?.Username; | |||
public byte[] RawPassword => _connectPacket?.Password; | |||
public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); | |||
public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; | |||
public bool? CleanSession => _connectPacket?.CleanSession; | |||
public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; | |||
public List<MqttUserProperty> UserProperties => _connectPacket?.Properties?.UserProperties; | |||
public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; | |||
public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; | |||
public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; | |||
public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; | |||
public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; | |||
public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; | |||
public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; | |||
public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; | |||
public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; | |||
/// <summary> | |||
/// Gets or sets a key/value collection that can be used to share data within the scope of this session. | |||
/// </summary> | |||
public IDictionary<object, object> SessionItems { get; } | |||
/// <summary> | |||
/// This is used for MQTTv3 only. | |||
/// </summary> | |||
@@ -1,21 +1,25 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using MQTTnet.Packets; | |||
using System.Collections.Generic; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttSubscriptionInterceptorContext : MqttBaseInterceptorContext | |||
public class MqttSubscriptionInterceptorContext | |||
{ | |||
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems) : base(connectPacket, sessionItems) | |||
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, IDictionary<object, object> sessionItems) | |||
{ | |||
ClientId = clientId; | |||
TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter)); | |||
TopicFilter = topicFilter; | |||
SessionItems = sessionItems; | |||
} | |||
public string ClientId { get; } | |||
public TopicFilter TopicFilter { get; set; } | |||
/// <summary> | |||
/// Gets or sets a key/value collection that can be used to share data within the scope of this session. | |||
/// </summary> | |||
public IDictionary<object, object> SessionItems { get; } | |||
public bool AcceptSubscription { get; set; } = true; | |||
public bool CloseConnection { get; set; } | |||
@@ -0,0 +1,61 @@ | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Client; | |||
using MQTTnet.Client.Subscribing; | |||
using MQTTnet.Server; | |||
using MQTTnet.Tests.Mockups; | |||
namespace MQTTnet.Tests | |||
{ | |||
[TestClass] | |||
public class Session_Tests | |||
{ | |||
[TestMethod] | |||
public async Task Set_Session_Item() | |||
{ | |||
using (var testEnvironment = new TestEnvironment()) | |||
{ | |||
var serverOptions = new MqttServerOptionsBuilder() | |||
.WithConnectionValidator(delegate (MqttConnectionValidatorContext context) | |||
{ | |||
// Don't validate anything. Just set some session items. | |||
context.SessionItems["can_subscribe_x"] = true; | |||
context.SessionItems["default_payload"] = "Hello World"; | |||
}) | |||
.WithSubscriptionInterceptor(delegate (MqttSubscriptionInterceptorContext context) | |||
{ | |||
if (context.TopicFilter.Topic == "x") | |||
{ | |||
context.AcceptSubscription = context.SessionItems["can_subscribe_x"] as bool? == true; | |||
} | |||
}) | |||
.WithApplicationMessageInterceptor(delegate (MqttApplicationMessageInterceptorContext context) | |||
{ | |||
context.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(context.SessionItems["default_payload"] as string); | |||
}); | |||
await testEnvironment.StartServerAsync(serverOptions); | |||
string receivedPayload = null; | |||
var client = await testEnvironment.ConnectClientAsync(); | |||
client.UseApplicationMessageReceivedHandler(delegate(MqttApplicationMessageReceivedEventArgs args) | |||
{ | |||
receivedPayload = args.ApplicationMessage.ConvertPayloadToString(); | |||
}); | |||
var subscribeResult = await client.SubscribeAsync("x"); | |||
Assert.AreEqual(MqttClientSubscribeResultCode.GrantedQoS0, subscribeResult.Items[0].ResultCode); | |||
var client2 = await testEnvironment.ConnectClientAsync(); | |||
await client2.PublishAsync("x"); | |||
await Task.Delay(1000); | |||
Assert.AreEqual("Hello World", receivedPayload); | |||
} | |||
} | |||
} | |||
} |