Browse Source

Expose connect packet in application message interceptor and subscription interceptor.

release/3.x.x
Christian Kratky 5 years ago
parent
commit
3eb5e82d10
11 changed files with 179 additions and 103 deletions
  1. +2
    -0
      Build/MQTTnet.nuspec
  2. +9
    -3
      Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs
  3. +6
    -3
      Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs
  4. +54
    -0
      Source/MQTTnet/Server/MqttBaseInterceptorContext.cs
  5. +9
    -8
      Source/MQTTnet/Server/MqttClientConnection.cs
  6. +9
    -3
      Source/MQTTnet/Server/MqttClientSession.cs
  7. +25
    -13
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  8. +21
    -16
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  9. +5
    -37
      Source/MQTTnet/Server/MqttConnectionValidatorContext.cs
  10. +5
    -3
      Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs
  11. +34
    -17
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs

+ 2
- 0
Build/MQTTnet.nuspec View File

@@ -11,6 +11,8 @@
<requireLicenseAcceptance>false</requireLicenseAcceptance>
<description>MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker) and supports v3.1.0, v3.1.1 and v5.0.0 of the MQTT protocol.</description>
<releaseNotes>
* [Server] Added items dictionary to client session in order to share data across interceptors as along as the session exists.
* [Server] Exposed CONNECT packet properties in Application Message and Subscription interceptor.
* [MQTTnet.Server] Added REST API for publishing basic messages.
</releaseNotes>
<copyright>Copyright Christian Kratky 2016-2019</copyright>


+ 9
- 3
Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs View File

@@ -24,12 +24,18 @@ namespace MQTTnet.Server.Mqtt
{
var pythonContext = new PythonDictionary
{
{ "client_id", context.ClientId },
{ "retain", context.ApplicationMessage.Retain },
{ "username", context.Username },
{ "password", context.Password },
{ "raw_password", new Bytes(context.RawPassword ?? new byte[0]) },
{ "clean_session", context.CleanSession},
{ "authentication_method", context.AuthenticationMethod},
{ "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) },
{ "accept_publish", context.AcceptPublish },
{ "close_connection", context.CloseConnection },
{ "client_id", context.ClientId },
{ "topic", context.ApplicationMessage.Topic },
{ "qos", (int)context.ApplicationMessage.QualityOfServiceLevel },
{ "retain", context.ApplicationMessage.Retain }
{ "qos", (int)context.ApplicationMessage.QualityOfServiceLevel }
};
_pythonScriptHostService.InvokeOptionalFunction("on_intercept_application_message", pythonContext);


+ 6
- 3
Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs View File

@@ -1,8 +1,11 @@
namespace MQTTnet.Server
using System.Collections.Generic;
using MQTTnet.Packets;

namespace MQTTnet.Server
{
public class MqttApplicationMessageInterceptorContext
public class MqttApplicationMessageInterceptorContext : MqttBaseInterceptorContext
{
public MqttApplicationMessageInterceptorContext(string clientId, MqttApplicationMessage applicationMessage)
public MqttApplicationMessageInterceptorContext(string clientId, IDictionary<object, object> sessionItems, MqttConnectPacket connectPacket, MqttApplicationMessage applicationMessage) : base(connectPacket, sessionItems)
{
ClientId = clientId;
ApplicationMessage = applicationMessage;


+ 54
- 0
Source/MQTTnet/Server/MqttBaseInterceptorContext.cs View File

@@ -0,0 +1,54 @@
using System.Collections.Generic;
using System.Text;
using MQTTnet.Packets;

namespace MQTTnet.Server
{
public class MqttBaseInterceptorContext
{
private readonly MqttConnectPacket _connectPacket;

protected MqttBaseInterceptorContext(MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems)
{
_connectPacket = connectPacket;
SessionItems = sessionItems;
}

public string Username => _connectPacket?.Username;

public byte[] RawPassword => _connectPacket?.Password;

public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]);

public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage;

public bool? CleanSession => _connectPacket?.CleanSession;

public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod;

public List<MqttUserProperty> UserProperties => _connectPacket?.Properties?.UserProperties;

public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData;

public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod;

public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize;

public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum;

public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum;

public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation;

public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation;

public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval;

public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval;

/// <summary>
/// Gets or sets a key/value collection that can be used to share data within the scope of this session.
/// </summary>
public IDictionary<object, object> SessionItems { get; }
}
}

+ 9
- 8
Source/MQTTnet/Server/MqttClientConnection.cs View File

@@ -31,7 +31,6 @@ namespace MQTTnet.Server
private readonly IMqttChannelAdapter _channelAdapter;
private readonly IMqttDataConverter _dataConverter;
private readonly string _endpoint;
private readonly MqttConnectPacket _connectPacket;
private readonly DateTime _connectedTimestamp;

private Task<MqttClientDisconnectType> _packageReceiverTask;
@@ -60,22 +59,24 @@ namespace MQTTnet.Server
_channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter));
_dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter;
_endpoint = _channelAdapter.Endpoint;
_connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket));
ConnectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket));

if (logger == null) throw new ArgumentNullException(nameof(logger));
_logger = logger.CreateChildLogger(nameof(MqttClientConnection));

_keepAliveMonitor = new MqttClientKeepAliveMonitor(_connectPacket.ClientId, StopAsync, _logger);
_keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, StopAsync, _logger);

_connectedTimestamp = DateTime.UtcNow;
_lastPacketReceivedTimestamp = _connectedTimestamp;
_lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp;
}

public string ClientId => _connectPacket.ClientId;
public MqttConnectPacket ConnectPacket { get; }

public MqttClientSession Session { get; }
public string ClientId => ConnectPacket.ClientId;

public MqttClientSession Session { get; }
public async Task StopAsync()
{
StopInternal();
@@ -133,12 +134,12 @@ namespace MQTTnet.Server
_channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted;
_channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted;

Session.WillMessage = _connectPacket.WillMessage;
Session.WillMessage = ConnectPacket.WillMessage;

Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger);

// TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor.
_keepAliveMonitor.Start(_connectPacket.KeepAlivePeriod, _cancellationToken.Token);
_keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token);

await SendAsync(
new MqttConnAckPacket
@@ -271,7 +272,7 @@ namespace MQTTnet.Server
private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket)
{
// TODO: Let the channel adapter create the packet.
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false);
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false);

await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false);



+ 9
- 3
Source/MQTTnet/Server/MqttClientSession.cs View File

@@ -12,11 +12,12 @@ namespace MQTTnet.Server

private readonly DateTime _createdTimestamp = DateTime.UtcNow;

public MqttClientSession(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger)
public MqttClientSession(string clientId, IDictionary<object, object> items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger)
{
ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
Items = items ?? throw new ArgumentNullException(nameof(items));

SubscriptionsManager = new MqttClientSubscriptionsManager(clientId, eventDispatcher, serverOptions);
SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions);
ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions);

if (logger == null) throw new ArgumentNullException(nameof(logger));
@@ -33,6 +34,11 @@ namespace MQTTnet.Server

public MqttClientSessionApplicationMessagesQueue ApplicationMessagesQueue { get; }

/// <summary>
/// Gets or sets a key/value collection that can be used to share data within the scope of this session.
/// </summary>
public IDictionary<object, object> Items { get; }

public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage)
{
var checkSubscriptionsResult = SubscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel);
@@ -48,7 +54,7 @@ namespace MQTTnet.Server

public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters, MqttRetainedMessagesManager retainedMessagesManager)
{
await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false);
await SubscriptionsManager.SubscribeAsync(topicFilters, null).ConfigureAwait(false);

var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false);
foreach (var matchingRetainedMessage in matchingRetainedMessages)


+ 25
- 13
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -20,7 +20,8 @@ namespace MQTTnet.Server
private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1);
private readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>();
private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();
private readonly IDictionary<object, object> _serverSessionItems = new ConcurrentDictionary<object, object>();

private readonly CancellationToken _cancellationToken;
private readonly MqttServerEventDispatcher _eventDispatcher;

@@ -241,19 +242,19 @@ namespace MQTTnet.Server

clientId = connectPacket.ClientId;

var validatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);
var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);

if (validatorContext.ReasonCode != MqttConnectReasonCode.Success)
if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success)
{
// Send failure response here without preparing a session. The result for a successful connect
// will be sent from the session itself.
var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(validatorContext);
var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext);
await channelAdapter.SendPacketAsync(connAckPacket, _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);

return;
}

var connection = await CreateConnectionAsync(channelAdapter, connectPacket).ConfigureAwait(false);
var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);

await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false);
@@ -302,8 +303,7 @@ namespace MQTTnet.Server
await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false);

// Check the client ID and set a random one if supported.
if (string.IsNullOrEmpty(connectPacket.ClientId) &&
channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500)
if (string.IsNullOrEmpty(connectPacket.ClientId) && channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500)
{
connectPacket.ClientId = context.AssignedClientIdentifier;
}
@@ -316,7 +316,7 @@ namespace MQTTnet.Server
return context;
}

private async Task<MqttClientConnection> CreateConnectionAsync(IMqttChannelAdapter channelAdapter, MqttConnectPacket connectPacket)
private async Task<MqttClientConnection> CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
{
await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false);
try
@@ -345,7 +345,7 @@ namespace MQTTnet.Server

if (session == null)
{
session = new MqttClientSession(connectPacket.ClientId, _eventDispatcher, _options, _logger);
session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _logger);
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
}

@@ -362,7 +362,7 @@ namespace MQTTnet.Server
}
}

private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection sender, MqttApplicationMessage applicationMessage)
private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage)
{
var interceptor = _options.ApplicationMessageInterceptor;
if (interceptor == null)
@@ -370,13 +370,25 @@ namespace MQTTnet.Server
return null;
}

var senderClientId = sender?.ClientId;
if (sender == null)
string senderClientId;
IDictionary<object, object> sessionItems;
MqttConnectPacket connectPacket;

var messageIsFromServer = senderConnection == null;
if (messageIsFromServer)
{
senderClientId = _options.ClientId;
sessionItems = _serverSessionItems;
connectPacket = null;
}
else
{
senderClientId = senderConnection.ClientId;
sessionItems = senderConnection.Session.Items;
connectPacket = senderConnection.ConnectPacket;
}

var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, applicationMessage);
var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, connectPacket, applicationMessage);
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
return interceptorContext;
}


+ 21
- 16
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs View File

@@ -10,20 +10,23 @@ namespace MQTTnet.Server
public class MqttClientSubscriptionsManager
{
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly IMqttServerOptions _options;
private readonly MqttClientSession _clientSession;
private readonly IMqttServerOptions _serverOptions;
private readonly MqttServerEventDispatcher _eventDispatcher;
private readonly string _clientId;

public MqttClientSubscriptionsManager(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions options)
public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions)
{
_clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
_options = options ?? throw new ArgumentNullException(nameof(options));
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession));

// TODO: Consider removing the server options here and build a new class "ISubscriptionInterceptor" and just pass it. The instance is generated in the root server class upon start.
_serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions));
_eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher));
}

public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket)
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket, MqttConnectPacket connectPacket)
{
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));
if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket));

var result = new MqttClientSubscribeResult
{
@@ -37,7 +40,7 @@ namespace MQTTnet.Server

foreach (var originalTopicFilter in subscribePacket.TopicFilters)
{
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false);
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter, connectPacket).ConfigureAwait(false);

var finalTopicFilter = interceptorContext.TopicFilter;

@@ -64,18 +67,20 @@ namespace MQTTnet.Server
_subscriptions[finalTopicFilter.Topic] = finalTopicFilter.QualityOfServiceLevel;
}

await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, finalTopicFilter).ConfigureAwait(false);
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
}
}

return result;
}

public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters)
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters, MqttConnectPacket connectPacket)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

foreach (var topicFilter in topicFilters)
{
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
var interceptorContext = await InterceptSubscribeAsync(topicFilter, connectPacket).ConfigureAwait(false);
if (!interceptorContext.AcceptSubscription)
{
continue;
@@ -88,7 +93,7 @@ namespace MQTTnet.Server
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
}

await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false);
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
}
}
}
@@ -119,7 +124,7 @@ namespace MQTTnet.Server

foreach (var topicFilter in unsubscribePacket.TopicFilters)
{
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false);
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
}
return unsubAckPacket;
@@ -190,12 +195,12 @@ namespace MQTTnet.Server
}
}

private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter)
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter, MqttConnectPacket connectPacket)
{
var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
if (_options.SubscriptionInterceptor != null)
var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, connectPacket, _clientSession.Items);
if (_serverOptions.SubscriptionInterceptor != null)
{
await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false);
await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false);
}

return context;


+ 5
- 37
Source/MQTTnet/Server/MqttConnectionValidatorContext.cs View File

@@ -1,7 +1,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using MQTTnet.Adapter;
using MQTTnet.Formatter;
using MQTTnet.Packets;
@@ -9,51 +9,19 @@ using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public class MqttConnectionValidatorContext
public class MqttConnectionValidatorContext : MqttBaseInterceptorContext
{
private readonly MqttConnectPacket _connectPacket;
private readonly IMqttChannelAdapter _clientAdapter;

public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter)
public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) : base(connectPacket, new ConcurrentDictionary<object, object>())
{
_connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket));
_connectPacket = connectPacket;
_clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter));
}

public string ClientId => _connectPacket.ClientId;

public string Username => _connectPacket.Username;

public byte[] RawPassword => _connectPacket.Password;

public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]);

public MqttApplicationMessage WillMessage => _connectPacket.WillMessage;

public bool CleanSession => _connectPacket.CleanSession;

public ushort KeepAlivePeriod => _connectPacket.KeepAlivePeriod;

public List<MqttUserProperty> UserProperties => _connectPacket.Properties?.UserProperties;

public byte[] AuthenticationData => _connectPacket.Properties?.AuthenticationData;

public string AuthenticationMethod => _connectPacket.Properties?.AuthenticationMethod;

public uint? MaximumPacketSize => _connectPacket.Properties?.MaximumPacketSize;

public ushort? ReceiveMaximum => _connectPacket.Properties?.ReceiveMaximum;

public ushort? TopicAliasMaximum => _connectPacket.Properties?.TopicAliasMaximum;

public bool? RequestProblemInformation => _connectPacket.Properties?.RequestProblemInformation;

public bool? RequestResponseInformation => _connectPacket.Properties?.RequestResponseInformation;

public uint? SessionExpiryInterval => _connectPacket.Properties?.SessionExpiryInterval;

public uint? WillDelayInterval => _connectPacket.Properties?.WillDelayInterval;

public string Endpoint => _clientAdapter.Endpoint;

public bool IsSecureConnection => _clientAdapter.IsSecureConnection;
@@ -61,7 +29,7 @@ namespace MQTTnet.Server
public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate;

public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion;
/// <summary>
/// This is used for MQTTv3 only.
/// </summary>


+ 5
- 3
Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs View File

@@ -1,10 +1,12 @@
using System;
using System.Collections.Generic;
using MQTTnet.Packets;

namespace MQTTnet.Server
{
public class MqttSubscriptionInterceptorContext
public class MqttSubscriptionInterceptorContext : MqttBaseInterceptorContext
{
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter)
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems) : base(connectPacket, sessionItems)
{
ClientId = clientId;
TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter));
@@ -13,7 +15,7 @@ namespace MQTTnet.Server
public string ClientId { get; }

public TopicFilter TopicFilter { get; set; }
public bool AcceptSubscription { get; set; } = true;

public bool CloseConnection { get; set; }


+ 34
- 17
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs View File

@@ -1,4 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Concurrent;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using MQTTnet.Server;
@@ -10,14 +12,17 @@ namespace MQTTnet.Tests
public class MqttSubscriptionsManager_Tests
{
[TestMethod]
public void MqttSubscriptionsManager_SubscribeSingleSuccess()
public async Task MqttSubscriptionsManager_SubscribeSingleSuccess()
{
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(),
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger());

var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.SubscribeAsync(sp).GetAwaiter().GetResult();
await sm.SubscribeAsync(sp, new MqttConnectPacket());

var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce);
Assert.IsTrue(result.IsSubscribed);
@@ -25,14 +30,17 @@ namespace MQTTnet.Tests
}

[TestMethod]
public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess()
public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess()
{
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(),
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger());

var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });

sm.SubscribeAsync(sp).GetAwaiter().GetResult();
await sm.SubscribeAsync(sp, new MqttConnectPacket());

var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed);
@@ -40,15 +48,18 @@ namespace MQTTnet.Tests
}

[TestMethod]
public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess()
public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess()
{
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(),
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger());

var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce });

sm.SubscribeAsync(sp).GetAwaiter().GetResult();
await sm.SubscribeAsync(sp, new MqttConnectPacket());

var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed);
@@ -56,33 +67,39 @@ namespace MQTTnet.Tests
}

[TestMethod]
public void MqttSubscriptionsManager_SubscribeSingleNoSuccess()
public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess()
{
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(),
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger());

var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.SubscribeAsync(sp).GetAwaiter().GetResult();
await sm.SubscribeAsync(sp, new MqttConnectPacket());

Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);
}

[TestMethod]
public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle()
public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle()
{
var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());
var s = new MqttClientSession("", new ConcurrentDictionary<object, object>(),
new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger());

var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions());

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.SubscribeAsync(sp).GetAwaiter().GetResult();
await sm.SubscribeAsync(sp, new MqttConnectPacket());

Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);

var up = new MqttUnsubscribePacket();
up.TopicFilters.Add("A/B/C");
sm.UnsubscribeAsync(up);
await sm.UnsubscribeAsync(up);

Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);
}


Loading…
Cancel
Save