@@ -11,6 +11,8 @@ | |||||
<requireLicenseAcceptance>false</requireLicenseAcceptance> | <requireLicenseAcceptance>false</requireLicenseAcceptance> | ||||
<description>MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker) and supports v3.1.0, v3.1.1 and v5.0.0 of the MQTT protocol.</description> | <description>MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker) and supports v3.1.0, v3.1.1 and v5.0.0 of the MQTT protocol.</description> | ||||
<releaseNotes> | <releaseNotes> | ||||
* [Server] Added items dictionary to client session in order to share data across interceptors as along as the session exists. | |||||
* [Server] Exposed CONNECT packet properties in Application Message and Subscription interceptor. | |||||
* [MQTTnet.Server] Added REST API for publishing basic messages. | * [MQTTnet.Server] Added REST API for publishing basic messages. | ||||
</releaseNotes> | </releaseNotes> | ||||
<copyright>Copyright Christian Kratky 2016-2019</copyright> | <copyright>Copyright Christian Kratky 2016-2019</copyright> | ||||
@@ -24,12 +24,18 @@ namespace MQTTnet.Server.Mqtt | |||||
{ | { | ||||
var pythonContext = new PythonDictionary | var pythonContext = new PythonDictionary | ||||
{ | { | ||||
{ "client_id", context.ClientId }, | |||||
{ "retain", context.ApplicationMessage.Retain }, | |||||
{ "username", context.Username }, | |||||
{ "password", context.Password }, | |||||
{ "raw_password", new Bytes(context.RawPassword ?? new byte[0]) }, | |||||
{ "clean_session", context.CleanSession}, | |||||
{ "authentication_method", context.AuthenticationMethod}, | |||||
{ "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, | |||||
{ "accept_publish", context.AcceptPublish }, | { "accept_publish", context.AcceptPublish }, | ||||
{ "close_connection", context.CloseConnection }, | { "close_connection", context.CloseConnection }, | ||||
{ "client_id", context.ClientId }, | |||||
{ "topic", context.ApplicationMessage.Topic }, | { "topic", context.ApplicationMessage.Topic }, | ||||
{ "qos", (int)context.ApplicationMessage.QualityOfServiceLevel }, | |||||
{ "retain", context.ApplicationMessage.Retain } | |||||
{ "qos", (int)context.ApplicationMessage.QualityOfServiceLevel } | |||||
}; | }; | ||||
_pythonScriptHostService.InvokeOptionalFunction("on_intercept_application_message", pythonContext); | _pythonScriptHostService.InvokeOptionalFunction("on_intercept_application_message", pythonContext); | ||||
@@ -1,8 +1,11 @@ | |||||
namespace MQTTnet.Server | |||||
using System.Collections.Generic; | |||||
using MQTTnet.Packets; | |||||
namespace MQTTnet.Server | |||||
{ | { | ||||
public class MqttApplicationMessageInterceptorContext | |||||
public class MqttApplicationMessageInterceptorContext : MqttBaseInterceptorContext | |||||
{ | { | ||||
public MqttApplicationMessageInterceptorContext(string clientId, MqttApplicationMessage applicationMessage) | |||||
public MqttApplicationMessageInterceptorContext(string clientId, IDictionary<object, object> sessionItems, MqttConnectPacket connectPacket, MqttApplicationMessage applicationMessage) : base(connectPacket, sessionItems) | |||||
{ | { | ||||
ClientId = clientId; | ClientId = clientId; | ||||
ApplicationMessage = applicationMessage; | ApplicationMessage = applicationMessage; | ||||
@@ -0,0 +1,54 @@ | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using MQTTnet.Packets; | |||||
namespace MQTTnet.Server | |||||
{ | |||||
public class MqttBaseInterceptorContext | |||||
{ | |||||
private readonly MqttConnectPacket _connectPacket; | |||||
protected MqttBaseInterceptorContext(MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems) | |||||
{ | |||||
_connectPacket = connectPacket; | |||||
SessionItems = sessionItems; | |||||
} | |||||
public string Username => _connectPacket?.Username; | |||||
public byte[] RawPassword => _connectPacket?.Password; | |||||
public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); | |||||
public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; | |||||
public bool? CleanSession => _connectPacket?.CleanSession; | |||||
public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; | |||||
public List<MqttUserProperty> UserProperties => _connectPacket?.Properties?.UserProperties; | |||||
public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; | |||||
public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; | |||||
public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; | |||||
public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; | |||||
public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; | |||||
public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; | |||||
public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; | |||||
public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; | |||||
public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; | |||||
/// <summary> | |||||
/// Gets or sets a key/value collection that can be used to share data within the scope of this session. | |||||
/// </summary> | |||||
public IDictionary<object, object> SessionItems { get; } | |||||
} | |||||
} |
@@ -31,7 +31,6 @@ namespace MQTTnet.Server | |||||
private readonly IMqttChannelAdapter _channelAdapter; | private readonly IMqttChannelAdapter _channelAdapter; | ||||
private readonly IMqttDataConverter _dataConverter; | private readonly IMqttDataConverter _dataConverter; | ||||
private readonly string _endpoint; | private readonly string _endpoint; | ||||
private readonly MqttConnectPacket _connectPacket; | |||||
private readonly DateTime _connectedTimestamp; | private readonly DateTime _connectedTimestamp; | ||||
private Task<MqttClientDisconnectType> _packageReceiverTask; | private Task<MqttClientDisconnectType> _packageReceiverTask; | ||||
@@ -60,22 +59,24 @@ namespace MQTTnet.Server | |||||
_channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); | _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); | ||||
_dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; | _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; | ||||
_endpoint = _channelAdapter.Endpoint; | _endpoint = _channelAdapter.Endpoint; | ||||
_connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); | |||||
ConnectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); | |||||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | if (logger == null) throw new ArgumentNullException(nameof(logger)); | ||||
_logger = logger.CreateChildLogger(nameof(MqttClientConnection)); | _logger = logger.CreateChildLogger(nameof(MqttClientConnection)); | ||||
_keepAliveMonitor = new MqttClientKeepAliveMonitor(_connectPacket.ClientId, StopAsync, _logger); | |||||
_keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, StopAsync, _logger); | |||||
_connectedTimestamp = DateTime.UtcNow; | _connectedTimestamp = DateTime.UtcNow; | ||||
_lastPacketReceivedTimestamp = _connectedTimestamp; | _lastPacketReceivedTimestamp = _connectedTimestamp; | ||||
_lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp; | _lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp; | ||||
} | } | ||||
public string ClientId => _connectPacket.ClientId; | |||||
public MqttConnectPacket ConnectPacket { get; } | |||||
public MqttClientSession Session { get; } | |||||
public string ClientId => ConnectPacket.ClientId; | |||||
public MqttClientSession Session { get; } | |||||
public async Task StopAsync() | public async Task StopAsync() | ||||
{ | { | ||||
StopInternal(); | StopInternal(); | ||||
@@ -133,12 +134,12 @@ namespace MQTTnet.Server | |||||
_channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; | _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; | ||||
_channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted; | _channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted; | ||||
Session.WillMessage = _connectPacket.WillMessage; | |||||
Session.WillMessage = ConnectPacket.WillMessage; | |||||
Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger); | Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger); | ||||
// TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor. | // TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor. | ||||
_keepAliveMonitor.Start(_connectPacket.KeepAlivePeriod, _cancellationToken.Token); | |||||
_keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token); | |||||
await SendAsync( | await SendAsync( | ||||
new MqttConnAckPacket | new MqttConnAckPacket | ||||
@@ -271,7 +272,7 @@ namespace MQTTnet.Server | |||||
private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) | private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) | ||||
{ | { | ||||
// TODO: Let the channel adapter create the packet. | // TODO: Let the channel adapter create the packet. | ||||
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); | |||||
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); | |||||
await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false); | await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false); | ||||
@@ -12,11 +12,12 @@ namespace MQTTnet.Server | |||||
private readonly DateTime _createdTimestamp = DateTime.UtcNow; | private readonly DateTime _createdTimestamp = DateTime.UtcNow; | ||||
public MqttClientSession(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger) | |||||
public MqttClientSession(string clientId, IDictionary<object, object> items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger) | |||||
{ | { | ||||
ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); | ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); | ||||
Items = items ?? throw new ArgumentNullException(nameof(items)); | |||||
SubscriptionsManager = new MqttClientSubscriptionsManager(clientId, eventDispatcher, serverOptions); | |||||
SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions); | |||||
ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); | ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); | ||||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | if (logger == null) throw new ArgumentNullException(nameof(logger)); | ||||
@@ -33,6 +34,11 @@ namespace MQTTnet.Server | |||||
public MqttClientSessionApplicationMessagesQueue ApplicationMessagesQueue { get; } | public MqttClientSessionApplicationMessagesQueue ApplicationMessagesQueue { get; } | ||||
/// <summary> | |||||
/// Gets or sets a key/value collection that can be used to share data within the scope of this session. | |||||
/// </summary> | |||||
public IDictionary<object, object> Items { get; } | |||||
public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage) | public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage) | ||||
{ | { | ||||
var checkSubscriptionsResult = SubscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel); | var checkSubscriptionsResult = SubscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel); | ||||
@@ -48,7 +54,7 @@ namespace MQTTnet.Server | |||||
public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters, MqttRetainedMessagesManager retainedMessagesManager) | public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters, MqttRetainedMessagesManager retainedMessagesManager) | ||||
{ | { | ||||
await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); | |||||
await SubscriptionsManager.SubscribeAsync(topicFilters, null).ConfigureAwait(false); | |||||
var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); | var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); | ||||
foreach (var matchingRetainedMessage in matchingRetainedMessages) | foreach (var matchingRetainedMessage in matchingRetainedMessages) | ||||
@@ -20,7 +20,8 @@ namespace MQTTnet.Server | |||||
private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1); | private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1); | ||||
private readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>(); | private readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>(); | ||||
private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>(); | private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>(); | ||||
private readonly IDictionary<object, object> _serverSessionItems = new ConcurrentDictionary<object, object>(); | |||||
private readonly CancellationToken _cancellationToken; | private readonly CancellationToken _cancellationToken; | ||||
private readonly MqttServerEventDispatcher _eventDispatcher; | private readonly MqttServerEventDispatcher _eventDispatcher; | ||||
@@ -241,19 +242,19 @@ namespace MQTTnet.Server | |||||
clientId = connectPacket.ClientId; | clientId = connectPacket.ClientId; | ||||
var validatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); | |||||
var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); | |||||
if (validatorContext.ReasonCode != MqttConnectReasonCode.Success) | |||||
if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) | |||||
{ | { | ||||
// Send failure response here without preparing a session. The result for a successful connect | // Send failure response here without preparing a session. The result for a successful connect | ||||
// will be sent from the session itself. | // will be sent from the session itself. | ||||
var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(validatorContext); | |||||
var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); | |||||
await channelAdapter.SendPacketAsync(connAckPacket, _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); | await channelAdapter.SendPacketAsync(connAckPacket, _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); | ||||
return; | return; | ||||
} | } | ||||
var connection = await CreateConnectionAsync(channelAdapter, connectPacket).ConfigureAwait(false); | |||||
var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); | |||||
await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); | await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); | ||||
@@ -302,8 +303,7 @@ namespace MQTTnet.Server | |||||
await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false); | await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false); | ||||
// Check the client ID and set a random one if supported. | // Check the client ID and set a random one if supported. | ||||
if (string.IsNullOrEmpty(connectPacket.ClientId) && | |||||
channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) | |||||
if (string.IsNullOrEmpty(connectPacket.ClientId) && channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) | |||||
{ | { | ||||
connectPacket.ClientId = context.AssignedClientIdentifier; | connectPacket.ClientId = context.AssignedClientIdentifier; | ||||
} | } | ||||
@@ -316,7 +316,7 @@ namespace MQTTnet.Server | |||||
return context; | return context; | ||||
} | } | ||||
private async Task<MqttClientConnection> CreateConnectionAsync(IMqttChannelAdapter channelAdapter, MqttConnectPacket connectPacket) | |||||
private async Task<MqttClientConnection> CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) | |||||
{ | { | ||||
await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); | await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); | ||||
try | try | ||||
@@ -345,7 +345,7 @@ namespace MQTTnet.Server | |||||
if (session == null) | if (session == null) | ||||
{ | { | ||||
session = new MqttClientSession(connectPacket.ClientId, _eventDispatcher, _options, _logger); | |||||
session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _logger); | |||||
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | ||||
} | } | ||||
@@ -362,7 +362,7 @@ namespace MQTTnet.Server | |||||
} | } | ||||
} | } | ||||
private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection sender, MqttApplicationMessage applicationMessage) | |||||
private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage) | |||||
{ | { | ||||
var interceptor = _options.ApplicationMessageInterceptor; | var interceptor = _options.ApplicationMessageInterceptor; | ||||
if (interceptor == null) | if (interceptor == null) | ||||
@@ -370,13 +370,25 @@ namespace MQTTnet.Server | |||||
return null; | return null; | ||||
} | } | ||||
var senderClientId = sender?.ClientId; | |||||
if (sender == null) | |||||
string senderClientId; | |||||
IDictionary<object, object> sessionItems; | |||||
MqttConnectPacket connectPacket; | |||||
var messageIsFromServer = senderConnection == null; | |||||
if (messageIsFromServer) | |||||
{ | { | ||||
senderClientId = _options.ClientId; | senderClientId = _options.ClientId; | ||||
sessionItems = _serverSessionItems; | |||||
connectPacket = null; | |||||
} | |||||
else | |||||
{ | |||||
senderClientId = senderConnection.ClientId; | |||||
sessionItems = senderConnection.Session.Items; | |||||
connectPacket = senderConnection.ConnectPacket; | |||||
} | } | ||||
var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, applicationMessage); | |||||
var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, connectPacket, applicationMessage); | |||||
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); | await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); | ||||
return interceptorContext; | return interceptorContext; | ||||
} | } | ||||
@@ -10,20 +10,23 @@ namespace MQTTnet.Server | |||||
public class MqttClientSubscriptionsManager | public class MqttClientSubscriptionsManager | ||||
{ | { | ||||
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>(); | private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>(); | ||||
private readonly IMqttServerOptions _options; | |||||
private readonly MqttClientSession _clientSession; | |||||
private readonly IMqttServerOptions _serverOptions; | |||||
private readonly MqttServerEventDispatcher _eventDispatcher; | private readonly MqttServerEventDispatcher _eventDispatcher; | ||||
private readonly string _clientId; | |||||
public MqttClientSubscriptionsManager(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions options) | |||||
public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions) | |||||
{ | { | ||||
_clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); | |||||
_options = options ?? throw new ArgumentNullException(nameof(options)); | |||||
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); | |||||
// TODO: Consider removing the server options here and build a new class "ISubscriptionInterceptor" and just pass it. The instance is generated in the root server class upon start. | |||||
_serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); | |||||
_eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); | _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); | ||||
} | } | ||||
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket) | |||||
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket, MqttConnectPacket connectPacket) | |||||
{ | { | ||||
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); | if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); | ||||
if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); | |||||
var result = new MqttClientSubscribeResult | var result = new MqttClientSubscribeResult | ||||
{ | { | ||||
@@ -37,7 +40,7 @@ namespace MQTTnet.Server | |||||
foreach (var originalTopicFilter in subscribePacket.TopicFilters) | foreach (var originalTopicFilter in subscribePacket.TopicFilters) | ||||
{ | { | ||||
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false); | |||||
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter, connectPacket).ConfigureAwait(false); | |||||
var finalTopicFilter = interceptorContext.TopicFilter; | var finalTopicFilter = interceptorContext.TopicFilter; | ||||
@@ -64,18 +67,20 @@ namespace MQTTnet.Server | |||||
_subscriptions[finalTopicFilter.Topic] = finalTopicFilter.QualityOfServiceLevel; | _subscriptions[finalTopicFilter.Topic] = finalTopicFilter.QualityOfServiceLevel; | ||||
} | } | ||||
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, finalTopicFilter).ConfigureAwait(false); | |||||
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); | |||||
} | } | ||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters) | |||||
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters, MqttConnectPacket connectPacket) | |||||
{ | { | ||||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||||
foreach (var topicFilter in topicFilters) | foreach (var topicFilter in topicFilters) | ||||
{ | { | ||||
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); | |||||
var interceptorContext = await InterceptSubscribeAsync(topicFilter, connectPacket).ConfigureAwait(false); | |||||
if (!interceptorContext.AcceptSubscription) | if (!interceptorContext.AcceptSubscription) | ||||
{ | { | ||||
continue; | continue; | ||||
@@ -88,7 +93,7 @@ namespace MQTTnet.Server | |||||
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; | _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; | ||||
} | } | ||||
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false); | |||||
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -119,7 +124,7 @@ namespace MQTTnet.Server | |||||
foreach (var topicFilter in unsubscribePacket.TopicFilters) | foreach (var topicFilter in unsubscribePacket.TopicFilters) | ||||
{ | { | ||||
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false); | |||||
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); | |||||
} | } | ||||
return unsubAckPacket; | return unsubAckPacket; | ||||
@@ -190,12 +195,12 @@ namespace MQTTnet.Server | |||||
} | } | ||||
} | } | ||||
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter) | |||||
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter, MqttConnectPacket connectPacket) | |||||
{ | { | ||||
var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter); | |||||
if (_options.SubscriptionInterceptor != null) | |||||
var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, connectPacket, _clientSession.Items); | |||||
if (_serverOptions.SubscriptionInterceptor != null) | |||||
{ | { | ||||
await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); | |||||
await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); | |||||
} | } | ||||
return context; | return context; | ||||
@@ -1,7 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Concurrent; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Security.Cryptography.X509Certificates; | using System.Security.Cryptography.X509Certificates; | ||||
using System.Text; | |||||
using MQTTnet.Adapter; | using MQTTnet.Adapter; | ||||
using MQTTnet.Formatter; | using MQTTnet.Formatter; | ||||
using MQTTnet.Packets; | using MQTTnet.Packets; | ||||
@@ -9,51 +9,19 @@ using MQTTnet.Protocol; | |||||
namespace MQTTnet.Server | namespace MQTTnet.Server | ||||
{ | { | ||||
public class MqttConnectionValidatorContext | |||||
public class MqttConnectionValidatorContext : MqttBaseInterceptorContext | |||||
{ | { | ||||
private readonly MqttConnectPacket _connectPacket; | private readonly MqttConnectPacket _connectPacket; | ||||
private readonly IMqttChannelAdapter _clientAdapter; | private readonly IMqttChannelAdapter _clientAdapter; | ||||
public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) | |||||
public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) : base(connectPacket, new ConcurrentDictionary<object, object>()) | |||||
{ | { | ||||
_connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); | |||||
_connectPacket = connectPacket; | |||||
_clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); | _clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); | ||||
} | } | ||||
public string ClientId => _connectPacket.ClientId; | public string ClientId => _connectPacket.ClientId; | ||||
public string Username => _connectPacket.Username; | |||||
public byte[] RawPassword => _connectPacket.Password; | |||||
public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); | |||||
public MqttApplicationMessage WillMessage => _connectPacket.WillMessage; | |||||
public bool CleanSession => _connectPacket.CleanSession; | |||||
public ushort KeepAlivePeriod => _connectPacket.KeepAlivePeriod; | |||||
public List<MqttUserProperty> UserProperties => _connectPacket.Properties?.UserProperties; | |||||
public byte[] AuthenticationData => _connectPacket.Properties?.AuthenticationData; | |||||
public string AuthenticationMethod => _connectPacket.Properties?.AuthenticationMethod; | |||||
public uint? MaximumPacketSize => _connectPacket.Properties?.MaximumPacketSize; | |||||
public ushort? ReceiveMaximum => _connectPacket.Properties?.ReceiveMaximum; | |||||
public ushort? TopicAliasMaximum => _connectPacket.Properties?.TopicAliasMaximum; | |||||
public bool? RequestProblemInformation => _connectPacket.Properties?.RequestProblemInformation; | |||||
public bool? RequestResponseInformation => _connectPacket.Properties?.RequestResponseInformation; | |||||
public uint? SessionExpiryInterval => _connectPacket.Properties?.SessionExpiryInterval; | |||||
public uint? WillDelayInterval => _connectPacket.Properties?.WillDelayInterval; | |||||
public string Endpoint => _clientAdapter.Endpoint; | public string Endpoint => _clientAdapter.Endpoint; | ||||
public bool IsSecureConnection => _clientAdapter.IsSecureConnection; | public bool IsSecureConnection => _clientAdapter.IsSecureConnection; | ||||
@@ -61,7 +29,7 @@ namespace MQTTnet.Server | |||||
public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; | public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; | ||||
public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; | public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; | ||||
/// <summary> | /// <summary> | ||||
/// This is used for MQTTv3 only. | /// This is used for MQTTv3 only. | ||||
/// </summary> | /// </summary> | ||||
@@ -1,10 +1,12 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using MQTTnet.Packets; | |||||
namespace MQTTnet.Server | namespace MQTTnet.Server | ||||
{ | { | ||||
public class MqttSubscriptionInterceptorContext | |||||
public class MqttSubscriptionInterceptorContext : MqttBaseInterceptorContext | |||||
{ | { | ||||
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter) | |||||
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems) : base(connectPacket, sessionItems) | |||||
{ | { | ||||
ClientId = clientId; | ClientId = clientId; | ||||
TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter)); | TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter)); | ||||
@@ -13,7 +15,7 @@ namespace MQTTnet.Server | |||||
public string ClientId { get; } | public string ClientId { get; } | ||||
public TopicFilter TopicFilter { get; set; } | public TopicFilter TopicFilter { get; set; } | ||||
public bool AcceptSubscription { get; set; } = true; | public bool AcceptSubscription { get; set; } = true; | ||||
public bool CloseConnection { get; set; } | public bool CloseConnection { get; set; } | ||||
@@ -1,4 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System.Collections.Concurrent; | |||||
using System.Threading.Tasks; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using MQTTnet.Packets; | using MQTTnet.Packets; | ||||
using MQTTnet.Protocol; | using MQTTnet.Protocol; | ||||
using MQTTnet.Server; | using MQTTnet.Server; | ||||
@@ -10,14 +12,17 @@ namespace MQTTnet.Tests | |||||
public class MqttSubscriptionsManager_Tests | public class MqttSubscriptionsManager_Tests | ||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void MqttSubscriptionsManager_SubscribeSingleSuccess() | |||||
public async Task MqttSubscriptionsManager_SubscribeSingleSuccess() | |||||
{ | { | ||||
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var sp = new MqttSubscribePacket(); | var sp = new MqttSubscribePacket(); | ||||
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); | sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); | ||||
sm.SubscribeAsync(sp).GetAwaiter().GetResult(); | |||||
await sm.SubscribeAsync(sp, new MqttConnectPacket()); | |||||
var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); | var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); | ||||
Assert.IsTrue(result.IsSubscribed); | Assert.IsTrue(result.IsSubscribed); | ||||
@@ -25,14 +30,17 @@ namespace MQTTnet.Tests | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() | |||||
public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() | |||||
{ | { | ||||
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var sp = new MqttSubscribePacket(); | var sp = new MqttSubscribePacket(); | ||||
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); | sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); | ||||
sm.SubscribeAsync(sp).GetAwaiter().GetResult(); | |||||
await sm.SubscribeAsync(sp, new MqttConnectPacket()); | |||||
var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); | var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); | ||||
Assert.IsTrue(result.IsSubscribed); | Assert.IsTrue(result.IsSubscribed); | ||||
@@ -40,15 +48,18 @@ namespace MQTTnet.Tests | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() | |||||
public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess() | |||||
{ | { | ||||
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var sp = new MqttSubscribePacket(); | var sp = new MqttSubscribePacket(); | ||||
sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); | sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); | ||||
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); | sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); | ||||
sm.SubscribeAsync(sp).GetAwaiter().GetResult(); | |||||
await sm.SubscribeAsync(sp, new MqttConnectPacket()); | |||||
var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); | var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); | ||||
Assert.IsTrue(result.IsSubscribed); | Assert.IsTrue(result.IsSubscribed); | ||||
@@ -56,33 +67,39 @@ namespace MQTTnet.Tests | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() | |||||
public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess() | |||||
{ | { | ||||
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var sp = new MqttSubscribePacket(); | var sp = new MqttSubscribePacket(); | ||||
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); | sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); | ||||
sm.SubscribeAsync(sp).GetAwaiter().GetResult(); | |||||
await sm.SubscribeAsync(sp, new MqttConnectPacket()); | |||||
Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | ||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() | |||||
public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() | |||||
{ | { | ||||
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(), | |||||
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); | |||||
var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); | |||||
var sp = new MqttSubscribePacket(); | var sp = new MqttSubscribePacket(); | ||||
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); | sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); | ||||
sm.SubscribeAsync(sp).GetAwaiter().GetResult(); | |||||
await sm.SubscribeAsync(sp, new MqttConnectPacket()); | |||||
Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | ||||
var up = new MqttUnsubscribePacket(); | var up = new MqttUnsubscribePacket(); | ||||
up.TopicFilters.Add("A/B/C"); | up.TopicFilters.Add("A/B/C"); | ||||
sm.UnsubscribeAsync(up); | |||||
await sm.UnsubscribeAsync(up); | |||||
Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); | ||||
} | } | ||||