Browse Source

Improve session handling.

release/3.x.x
Christian Kratky 5 years ago
parent
commit
96a67579e2
13 changed files with 477 additions and 218 deletions
  1. +8
    -6
      Source/MQTTnet/Client/MqttClient.cs
  2. +2
    -3
      Source/MQTTnet/Client/MqttClientExtensions.cs
  3. +1
    -1
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  4. +5
    -1
      Source/MQTTnet/Server/IMqttClientSessionStatus.cs
  5. +2
    -2
      Source/MQTTnet/Server/IMqttServer.cs
  6. +92
    -78
      Source/MQTTnet/Server/MqttClientSession.cs
  7. +19
    -13
      Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs
  8. +3
    -1
      Source/MQTTnet/Server/MqttClientSessionStatus.cs
  9. +65
    -50
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  10. +3
    -3
      Source/MQTTnet/Server/MqttServer.cs
  11. +45
    -0
      Source/MQTTnet/Server/MqttServerExtensions.cs
  12. +140
    -60
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  13. +92
    -0
      Tests/MQTTnet.Core.Tests/TestSetup.cs

+ 8
- 6
Source/MQTTnet/Client/MqttClient.cs View File

@@ -66,24 +66,26 @@ namespace MQTTnet.Client
_packetIdentifierProvider.Reset(); _packetIdentifierProvider.Reset();
_packetDispatcher.Reset(); _packetDispatcher.Reset();


_cancellationTokenSource = new CancellationTokenSource();
var cancellationTokenSource = new CancellationTokenSource();
_cancellationTokenSource = cancellationTokenSource;

_disconnectGate = 0; _disconnectGate = 0;
_adapter = _adapterFactory.CreateClientAdapter(options, _logger); _adapter = _adapterFactory.CreateClientAdapter(options, _logger);


_logger.Verbose($"Trying to connect with server ({Options.ChannelOptions})."); _logger.Verbose($"Trying to connect with server ({Options.ChannelOptions}).");
await _adapter.ConnectAsync(Options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false);
await _adapter.ConnectAsync(Options.CommunicationTimeout, cancellationTokenSource.Token).ConfigureAwait(false);
_logger.Verbose("Connection with server established."); _logger.Verbose("Connection with server established.");


StartReceivingPackets(_cancellationTokenSource.Token);
StartReceivingPackets(cancellationTokenSource.Token);


var connectResult = await AuthenticateAsync(options.WillMessage, _cancellationTokenSource.Token).ConfigureAwait(false);
var connectResult = await AuthenticateAsync(options.WillMessage, cancellationTokenSource.Token).ConfigureAwait(false);
_logger.Verbose("MQTT connection with server established."); _logger.Verbose("MQTT connection with server established.");


_sendTracker.Restart(); _sendTracker.Restart();


if (Options.KeepAlivePeriod != TimeSpan.Zero) if (Options.KeepAlivePeriod != TimeSpan.Zero)
{ {
StartSendingKeepAliveMessages(_cancellationTokenSource.Token);
StartSendingKeepAliveMessages(cancellationTokenSource.Token);
} }


IsConnected = true; IsConnected = true;
@@ -112,7 +114,7 @@ namespace MQTTnet.Client
{ {
_cleanDisconnectInitiated = true; _cleanDisconnectInitiated = true;


if (IsConnected && !_cancellationTokenSource.IsCancellationRequested)
if (IsConnected && _cancellationTokenSource?.IsCancellationRequested == false)
{ {
var disconnectPacket = CreateDisconnectPacket(options); var disconnectPacket = CreateDisconnectPacket(options);
await SendAsync(disconnectPacket, _cancellationTokenSource.Token).ConfigureAwait(false); await SendAsync(disconnectPacket, _cancellationTokenSource.Token).ConfigureAwait(false);


+ 2
- 3
Source/MQTTnet/Client/MqttClientExtensions.cs View File

@@ -1,5 +1,4 @@
using System; using System;
using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Client.Subscribing; using MQTTnet.Client.Subscribing;
using MQTTnet.Client.Unsubscribing; using MQTTnet.Client.Unsubscribing;
@@ -21,7 +20,7 @@ namespace MQTTnet.Client
if (client == null) throw new ArgumentNullException(nameof(client)); if (client == null) throw new ArgumentNullException(nameof(client));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


return client.SubscribeAsync(topicFilters.ToList());
return client.SubscribeAsync(topicFilters);
} }


public static Task<MqttClientSubscribeResult> SubscribeAsync(this IMqttClient client, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel) public static Task<MqttClientSubscribeResult> SubscribeAsync(this IMqttClient client, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel)
@@ -45,7 +44,7 @@ namespace MQTTnet.Client
if (client == null) throw new ArgumentNullException(nameof(client)); if (client == null) throw new ArgumentNullException(nameof(client));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


return client.UnsubscribeAsync(topicFilters.ToList());
return client.UnsubscribeAsync(topicFilters);
} }
} }
} }

+ 1
- 1
Source/MQTTnet/Implementations/MqttTcpChannel.cs View File

@@ -93,7 +93,7 @@ namespace MQTTnet.Implementations
using (cancellationToken.Register(() => _socket.Dispose())) using (cancellationToken.Register(() => _socket.Dispose()))
{ {
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken);
//await _stream.FlushAsync(cancellationToken);
} }
} }




+ 5
- 1
Source/MQTTnet/Server/IMqttClientSessionStatus.cs View File

@@ -18,7 +18,11 @@ namespace MQTTnet.Server


TimeSpan LastNonKeepAlivePacketReceived { get; } TimeSpan LastNonKeepAlivePacketReceived { get; }


int PendingApplicationMessagesCount { get; }
long PendingApplicationMessagesCount { get; }

long ReceivedApplicationMessagesCount { get; }

long SentApplicationMessagesCount { get; }


Task DisconnectAsync(); Task DisconnectAsync();




+ 2
- 2
Source/MQTTnet/Server/IMqttServer.cs View File

@@ -21,8 +21,8 @@ namespace MQTTnet.Server
IList<MqttApplicationMessage> GetRetainedMessages(); IList<MqttApplicationMessage> GetRetainedMessages();
Task ClearRetainedMessagesAsync(); Task ClearRetainedMessagesAsync();


Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters);
Task UnsubscribeAsync(string clientId, IList<string> topicFilters);
Task SubscribeAsync(string clientId, IEnumerable<TopicFilter> topicFilters);
Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters);


Task StartAsync(IMqttServerOptions options); Task StartAsync(IMqttServerOptions options);
Task StopAsync(); Task StopAsync();


+ 92
- 78
Source/MQTTnet/Server/MqttClientSession.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Adapter; using MQTTnet.Adapter;
@@ -18,7 +19,7 @@ namespace MQTTnet.Server
private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly MqttRetainedMessagesManager _retainedMessagesManager;
private readonly MqttServerEventDispatcher _eventDispatcher; private readonly MqttServerEventDispatcher _eventDispatcher;
private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; private readonly MqttClientKeepAliveMonitor _keepAliveMonitor;
private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue;
private readonly MqttClientSessionPendingMessagesQueue _pendingMessagesQueue;
private readonly MqttClientSubscriptionsManager _subscriptionsManager; private readonly MqttClientSubscriptionsManager _subscriptionsManager;
private readonly MqttClientSessionsManager _sessionsManager; private readonly MqttClientSessionsManager _sessionsManager;


@@ -31,6 +32,9 @@ namespace MQTTnet.Server
private Task _workerTask; private Task _workerTask;
private IMqttChannelAdapter _channelAdapter; private IMqttChannelAdapter _channelAdapter;


private long _receivedMessagesCount;
private bool _isCleanSession = true;

public MqttClientSession( public MqttClientSession(
string clientId, string clientId,
IMqttServerOptions options, IMqttServerOptions options,
@@ -52,7 +56,7 @@ namespace MQTTnet.Server


_keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger); _keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger);
_subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, eventDispatcher); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, eventDispatcher);
_pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger);
_pendingMessagesQueue = new MqttClientSessionPendingMessagesQueue(_options, this, _logger);
} }


public string ClientId { get; } public string ClientId { get; }
@@ -63,28 +67,38 @@ namespace MQTTnet.Server
status.IsConnected = _cancellationTokenSource != null; status.IsConnected = _cancellationTokenSource != null;
status.Endpoint = _channelAdapter?.Endpoint; status.Endpoint = _channelAdapter?.Endpoint;
status.ProtocolVersion = _channelAdapter?.PacketFormatterAdapter?.ProtocolVersion; status.ProtocolVersion = _channelAdapter?.PacketFormatterAdapter?.ProtocolVersion;
status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count;
status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count;
status.ReceivedApplicationMessagesCount = _pendingMessagesQueue.SentMessagesCount;
status.SentApplicationMessagesCount = Interlocked.Read(ref _receivedMessagesCount);
status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived; status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived;
status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived;
} }


public Task StopAsync(MqttClientDisconnectType type)
public async Task StopAsync(MqttClientDisconnectType type)
{ {
return StopAsync(type, false);
StopInternal(type);

var task = _workerTask;
if (task != null && !task.IsCompleted)
{
await task.ConfigureAwait(false);
}
} }


public async Task SubscribeAsync(IList<TopicFilter> topicFilters)
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters)
{ {
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


var topicFiltersCollection = topicFilters.ToList();

var packet = new MqttSubscribePacket(); var packet = new MqttSubscribePacket();
packet.TopicFilters.AddRange(topicFilters);
packet.TopicFilters.AddRange(topicFiltersCollection);


await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false); await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false);
await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false);
await EnqueueSubscribedRetainedMessagesAsync(topicFiltersCollection).ConfigureAwait(false);
} }


public Task UnsubscribeAsync(IList<string> topicFilters)
public Task UnsubscribeAsync(IEnumerable<string> topicFilters)
{ {
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


@@ -98,12 +112,12 @@ namespace MQTTnet.Server


public void ClearPendingApplicationMessages() public void ClearPendingApplicationMessages()
{ {
_pendingPacketsQueue.Clear();
_pendingMessagesQueue.Clear();
} }


public void Dispose() public void Dispose()
{ {
_pendingPacketsQueue?.Dispose();
_pendingMessagesQueue?.Dispose();


_cancellationTokenSource?.Cancel(); _cancellationTokenSource?.Cancel();
_cancellationTokenSource?.Dispose(); _cancellationTokenSource?.Dispose();
@@ -161,7 +175,7 @@ namespace MQTTnet.Server
publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel;
} }


_pendingPacketsQueue.Enqueue(publishPacket);
_pendingMessagesQueue.Enqueue(publishPacket);
} }


private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
@@ -170,26 +184,41 @@ namespace MQTTnet.Server


try try
{ {
channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted;
channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted;
_logger.Info("Client '{0}': Connected.", ClientId);
_eventDispatcher.OnClientConnected(ClientId);

_channelAdapter = channelAdapter;

_channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted;
_channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted;


_cancellationTokenSource = new CancellationTokenSource();
var cancellationTokenSource = new CancellationTokenSource();
_cancellationTokenSource = cancellationTokenSource;


_wasCleanDisconnect = false; _wasCleanDisconnect = false;
_willMessage = connectPacket.WillMessage; _willMessage = connectPacket.WillMessage;


_pendingPacketsQueue.Start(channelAdapter, _cancellationTokenSource.Token);
_keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token);
_pendingMessagesQueue.Start(channelAdapter, cancellationTokenSource.Token);
_keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, cancellationTokenSource.Token);


_channelAdapter = channelAdapter;
await channelAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ReturnCode = MqttConnectReturnCode.ConnectionAccepted,
ReasonCode = MqttConnectReasonCode.Success,
IsSessionPresent = _isCleanSession
},
cancellationTokenSource.Token).ConfigureAwait(false);


while (!_cancellationTokenSource.IsCancellationRequested)
_isCleanSession = false;

while (!cancellationTokenSource.IsCancellationRequested)
{ {
var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false);
var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, cancellationTokenSource.Token).ConfigureAwait(false);
if (packet != null) if (packet != null)
{ {
_keepAliveMonitor.PacketReceived(packet); _keepAliveMonitor.PacketReceived(packet);
await ProcessReceivedPacketAsync(channelAdapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false);
await ProcessReceivedPacketAsync(channelAdapter, packet, cancellationTokenSource.Token).ConfigureAwait(false);
} }
} }
} }
@@ -203,6 +232,9 @@ namespace MQTTnet.Server
if (exception is MqttCommunicationClosedGracefullyException) if (exception is MqttCommunicationClosedGracefullyException)
{ {
_logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId); _logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId);

StopInternal(MqttClientDisconnectType.Clean);
return;
} }
else else
{ {
@@ -214,69 +246,50 @@ namespace MQTTnet.Server
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId);
} }


await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
StopInternal(MqttClientDisconnectType.NotClean);
} }
finally finally
{ {
channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;

_cancellationTokenSource?.Cancel(false);
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;

_workerTask = null;
}
}

private async Task StopAsync(MqttClientDisconnectType type, bool isInsideSession)
{
try
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

_wasCleanDisconnect = type == MqttClientDisconnectType.Clean;

_cancellationTokenSource?.Cancel(false);

if (_willMessage != null && !_wasCleanDisconnect) if (_willMessage != null && !_wasCleanDisconnect)
{ {
_sessionsManager.EnqueueApplicationMessage(this, _willMessage); _sessionsManager.EnqueueApplicationMessage(this, _willMessage);
} }


_willMessage = null; _willMessage = null;

if (!isInsideSession)
{
if (_workerTask != null)
{
await _workerTask.ConfigureAwait(false);
}
}
await Task.FromResult(0);
}
finally
{
_channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
_channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
_channelAdapter = null;
_logger.Info("Client '{0}': Session stopped.", ClientId); _logger.Info("Client '{0}': Session stopped.", ClientId);
_eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect); _eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect);

_workerTask = null;
} }
} }
private void StopInternal(MqttClientDisconnectType type)
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

_wasCleanDisconnect = type == MqttClientDisconnectType.Clean;
_cancellationTokenSource?.Cancel(false);
}


private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
private Task ProcessReceivedPacketAsync(IMqttChannelAdapter channelAdapter, MqttBasePacket packet, CancellationToken cancellationToken)
{ {
if (packet is MqttPublishPacket publishPacket) if (packet is MqttPublishPacket publishPacket)
{ {
return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken);
return HandleIncomingPublishPacketAsync(channelAdapter, publishPacket, cancellationToken);
} }


if (packet is MqttPingReqPacket) if (packet is MqttPingReqPacket)
{ {
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
return channelAdapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
} }


if (packet is MqttPubRelPacket pubRelPacket) if (packet is MqttPubRelPacket pubRelPacket)
@@ -287,7 +300,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubCompReasonCode.Success ReasonCode = MqttPubCompReasonCode.Success
}; };


return adapter.SendPacketAsync(responsePacket, cancellationToken);
return channelAdapter.SendPacketAsync(responsePacket, cancellationToken);
} }


if (packet is MqttPubRecPacket pubRecPacket) if (packet is MqttPubRecPacket pubRecPacket)
@@ -298,7 +311,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubRelReasonCode.Success ReasonCode = MqttPubRelReasonCode.Success
}; };


return adapter.SendPacketAsync(responsePacket, cancellationToken);
return channelAdapter.SendPacketAsync(responsePacket, cancellationToken);
} }


if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
@@ -308,27 +321,24 @@ namespace MQTTnet.Server


if (packet is MqttSubscribePacket subscribePacket) if (packet is MqttSubscribePacket subscribePacket)
{ {
return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken);
return HandleIncomingSubscribePacketAsync(channelAdapter, subscribePacket, cancellationToken);
} }


if (packet is MqttUnsubscribePacket unsubscribePacket) if (packet is MqttUnsubscribePacket unsubscribePacket)
{ {
return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken);
return HandleIncomingUnsubscribePacketAsync(channelAdapter, unsubscribePacket, cancellationToken);
} }


if (packet is MqttDisconnectPacket) if (packet is MqttDisconnectPacket)
{ {
return StopAsync(MqttClientDisconnectType.Clean, true);
}

if (packet is MqttConnectPacket)
{
return StopAsync(MqttClientDisconnectType.NotClean, true);
StopInternal(MqttClientDisconnectType.Clean);
return Task.FromResult(0);
} }


_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet);
_logger.Warning(null, "Client '{0}': Received invalid packet ({1}). Closing connection.", ClientId, packet);


return StopAsync(MqttClientDisconnectType.NotClean, true);
StopInternal(MqttClientDisconnectType.NotClean);
return Task.FromResult(0);
} }


private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters) private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
@@ -347,7 +357,8 @@ namespace MQTTnet.Server


if (subscribeResult.CloseConnection) if (subscribeResult.CloseConnection)
{ {
await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
StopInternal(MqttClientDisconnectType.NotClean);
return;
} }


await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false);
@@ -361,12 +372,13 @@ namespace MQTTnet.Server


private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{ {
Interlocked.Increment(ref _receivedMessagesCount);

switch (publishPacket.QualityOfServiceLevel) switch (publishPacket.QualityOfServiceLevel)
{ {
case MqttQualityOfServiceLevel.AtMostOnce: case MqttQualityOfServiceLevel.AtMostOnce:
{ {
HandleIncomingPublishPacketWithQoS0(publishPacket);
return Task.FromResult(0);
return HandleIncomingPublishPacketWithQoS0Async(publishPacket);
} }
case MqttQualityOfServiceLevel.AtLeastOnce: case MqttQualityOfServiceLevel.AtLeastOnce:
{ {
@@ -383,11 +395,13 @@ namespace MQTTnet.Server
} }
} }


private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket)
private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket)
{ {
_sessionsManager.EnqueueApplicationMessage( _sessionsManager.EnqueueApplicationMessage(
this, this,
_channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket));

return Task.FromResult(0);
} }


private Task HandleIncomingPublishPacketWithQoS1Async( private Task HandleIncomingPublishPacketWithQoS1Async(


Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs → Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs View File

@@ -11,22 +11,24 @@ using MQTTnet.Protocol;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
public class MqttClientPendingPacketsQueue : IDisposable
public class MqttClientSessionPendingMessagesQueue : IDisposable
{ {
private readonly Queue<MqttBasePacket> _queue = new Queue<MqttBasePacket>(); private readonly Queue<MqttBasePacket> _queue = new Queue<MqttBasePacket>();
private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent();
private readonly AsyncAutoResetEvent _queueLock = new AsyncAutoResetEvent();


private readonly IMqttServerOptions _options; private readonly IMqttServerOptions _options;
private readonly MqttClientSession _clientSession; private readonly MqttClientSession _clientSession;
private readonly IMqttNetChildLogger _logger; private readonly IMqttNetChildLogger _logger;


public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
private long _sentPacketsCount;

public MqttClientSessionPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
{ {
if (logger == null) throw new ArgumentNullException(nameof(logger)); if (logger == null) throw new ArgumentNullException(nameof(logger));
_options = options ?? throw new ArgumentNullException(nameof(options)); _options = options ?? throw new ArgumentNullException(nameof(options));
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession));


_logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue));
_logger = logger.CreateChildLogger(nameof(MqttClientSessionPendingMessagesQueue));
} }


public int Count public int Count
@@ -40,6 +42,8 @@ namespace MQTTnet.Server
} }
} }


public long SentMessagesCount => Interlocked.Read(ref _sentPacketsCount);

public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken) public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken)
{ {
if (adapter == null) throw new ArgumentNullException(nameof(adapter)); if (adapter == null) throw new ArgumentNullException(nameof(adapter));
@@ -52,7 +56,7 @@ namespace MQTTnet.Server
Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken); Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken);
} }


public void Enqueue(MqttBasePacket packet)
public void Enqueue(MqttPublishPacket packet)
{ {
if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packet == null) throw new ArgumentNullException(nameof(packet));


@@ -70,11 +74,11 @@ namespace MQTTnet.Server
_queue.Dequeue(); _queue.Dequeue();
} }
} }
_queue.Enqueue(packet); _queue.Enqueue(packet);
} }


_queueAutoResetEvent.Set();
_queueLock.Set();


_logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId);
} }
@@ -114,6 +118,11 @@ namespace MQTTnet.Server
MqttBasePacket packet = null; MqttBasePacket packet = null;
try try
{ {
if (cancellationToken.IsCancellationRequested)
{
return;
}

lock (_queue) lock (_queue)
{ {
if (_queue.Count > 0) if (_queue.Count > 0)
@@ -124,18 +133,15 @@ namespace MQTTnet.Server


if (packet == null) if (packet == null)
{ {
await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false);
return;
}

if (cancellationToken.IsCancellationRequested)
{
await _queueLock.WaitOneAsync(cancellationToken).ConfigureAwait(false);
return; return;
} }


await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false);


_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId);

Interlocked.Increment(ref _sentPacketsCount);
} }
catch (Exception exception) catch (Exception exception)
{ {

+ 3
- 1
Source/MQTTnet/Server/MqttClientSessionStatus.cs View File

@@ -21,7 +21,9 @@ namespace MQTTnet.Server
public MqttProtocolVersion? ProtocolVersion { get; set; } public MqttProtocolVersion? ProtocolVersion { get; set; }
public TimeSpan LastPacketReceived { get; set; } public TimeSpan LastPacketReceived { get; set; }
public TimeSpan LastNonKeepAlivePacketReceived { get; set; } public TimeSpan LastNonKeepAlivePacketReceived { get; set; }
public int PendingApplicationMessagesCount { get; set; }
public long PendingApplicationMessagesCount { get; set; }
public long ReceivedApplicationMessagesCount { get; set; }
public long SentApplicationMessagesCount { get; set; }


public Task DisconnectAsync() public Task DisconnectAsync()
{ {


+ 65
- 50
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Adapter; using MQTTnet.Adapter;
@@ -49,27 +50,31 @@ namespace MQTTnet.Server


public async Task StopAsync() public async Task StopAsync()
{ {
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
List<MqttClientSession> sessions;
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
foreach (var session in _sessions)
{
await session.Value.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
}
sessions = _sessions.Values.ToList();
}


_sessions.Clear();
foreach (var session in sessions)
{
await session.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);
} }
} }


public Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter) public Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter)
{ {
return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken);
return HandleConnectionAsync(clientAdapter, _cancellationToken);

// TODO: Check if Task.Run is required.
//return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken);
} }


public async Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync() public async Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
{ {
var result = new List<IMqttClientSessionStatus>(); var result = new List<IMqttClientSessionStatus>();


using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
foreach (var session in _sessions.Values) foreach (var session in _sessions.Values)
{ {
@@ -90,42 +95,47 @@ namespace MQTTnet.Server
_messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken);
} }


public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
public async Task SubscribeAsync(string clientId, IEnumerable<TopicFilter> topicFilters)
{ {
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


lock (_sessions)
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
if (!_sessions.TryGetValue(clientId, out var session)) if (!_sessions.TryGetValue(clientId, out var session))
{ {
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
} }


return session.SubscribeAsync(topicFilters);
await session.SubscribeAsync(topicFilters).ConfigureAwait(false);
} }
} }


public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
public async Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters)
{ {
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


lock (_sessions)
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
if (!_sessions.TryGetValue(clientId, out var session)) if (!_sessions.TryGetValue(clientId, out var session))
{ {
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
} }


return session.UnsubscribeAsync(topicFilters);
await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false);
} }
} }


public async Task DeleteSessionAsync(string clientId) public async Task DeleteSessionAsync(string clientId)
{ {
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
if (_sessions.TryGetValue(clientId, out var session))
{
session.Dispose();
}

_sessions.Remove(clientId); _sessions.Remove(clientId);
} }


@@ -187,7 +197,7 @@ namespace MQTTnet.Server
await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
} }


using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
foreach (var clientSession in _sessions.Values) foreach (var clientSession in _sessions.Values)
{ {
@@ -207,49 +217,37 @@ namespace MQTTnet.Server
} }
} }


private async Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
private async Task HandleConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{ {
var clientId = string.Empty; var clientId = string.Empty;
try try
{ {
var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
if (!(firstPacket is MqttConnectPacket connectPacket)) if (!(firstPacket is MqttConnectPacket connectPacket))
{ {
_logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", clientAdapter.Endpoint);
_logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint);
return; return;
} }


clientId = connectPacket.ClientId; clientId = connectPacket.ClientId;


var connectReturnCode = await ValidateConnectionAsync(connectPacket, clientAdapter).ConfigureAwait(false);
var connectReturnCode = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{ {
await clientAdapter.SendPacketAsync(
await channelAdapter.SendPacketAsync(
new MqttConnAckPacket new MqttConnAckPacket
{ {
ReturnCode = connectReturnCode
ReturnCode = connectReturnCode,
ReasonCode = MqttConnectReasonCode.NotAuthorized
}, },
cancellationToken).ConfigureAwait(false); cancellationToken).ConfigureAwait(false);


return; return;
} }


var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);

await clientAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ReturnCode = connectReturnCode,
ReasonCode = MqttConnectReasonCode.Success,
IsSessionPresent = result.IsExistingSession
},
cancellationToken).ConfigureAwait(false);

_logger.Info("Client '{0}': Connected.", clientId);
_eventDispatcher.OnClientConnected(clientId);

await result.Session.RunAsync(connectPacket, clientAdapter).ConfigureAwait(false);
var session = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);
await session.RunAsync(connectPacket, channelAdapter).ConfigureAwait(false);
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
@@ -260,12 +258,15 @@ namespace MQTTnet.Server
} }
finally finally
{ {
await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
clientAdapter.Dispose();
await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false);


if (!_options.EnablePersistentSessions) if (!_options.EnablePersistentSessions)
{ {
await DeleteSessionAsync(clientId).ConfigureAwait(false);
// TODO: Check if the session will be used later.
// Consider reference counter or "Recycle" property
// Or add timer (will be required for MQTTv5 (session life time) "IsActiveProperty".
//öö
//await DeleteSessionAsync(clientId).ConfigureAwait(false);
} }
} }
} }
@@ -288,18 +289,17 @@ namespace MQTTnet.Server
return context.ReturnCode; return context.ReturnCode;
} }


private async Task<PrepareClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
private async Task<MqttClientSession> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
{ {
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
if (isSessionPresent) if (isSessionPresent)
{ {
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);

if (connectPacket.CleanSession) if (connectPacket.CleanSession)
{ {
_sessions.Remove(connectPacket.ClientId);
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);

clientSession.Dispose(); clientSession.Dispose();
clientSession = null; clientSession = null;


@@ -307,22 +307,21 @@ namespace MQTTnet.Server
} }
else else
{ {
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);

_logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId); _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId);
} }
} }


var isExistingSession = true;
if (clientSession == null) if (clientSession == null)
{ {
isExistingSession = false;

clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _eventDispatcher, _logger); clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _eventDispatcher, _logger);
_sessions[connectPacket.ClientId] = clientSession; _sessions[connectPacket.ClientId] = clientSession;


_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
} }


return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession };
return clientSession;
} }
} }


@@ -338,5 +337,21 @@ namespace MQTTnet.Server
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
return interceptorContext; return interceptorContext;
} }

private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter)
{
try
{
await channelAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while disconnecting client channel.");
}
finally
{
channelAdapter.Dispose();
}
}
} }
} }

+ 3
- 3
Source/MQTTnet/Server/MqttServer.cs View File

@@ -56,7 +56,7 @@ namespace MQTTnet.Server
return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult(); return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult();
} }


public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
public Task SubscribeAsync(string clientId, IEnumerable<TopicFilter> topicFilters)
{ {
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
@@ -64,7 +64,7 @@ namespace MQTTnet.Server
return _clientSessionsManager.SubscribeAsync(clientId, topicFilters); return _clientSessionsManager.SubscribeAsync(clientId, topicFilters);
} }


public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
public Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters)
{ {
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
@@ -118,7 +118,7 @@ namespace MQTTnet.Server


_cancellationTokenSource.Cancel(false); _cancellationTokenSource.Cancel(false);
_clientSessionsManager.StopAsync().ConfigureAwait(false);
await _clientSessionsManager.StopAsync().ConfigureAwait(false);


foreach (var adapter in _adapters) foreach (var adapter in _adapters)
{ {


+ 45
- 0
Source/MQTTnet/Server/MqttServerExtensions.cs View File

@@ -0,0 +1,45 @@
using System;
using System.Threading.Tasks;
using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public static class MqttServerExtensions
{
public static Task SubscribeAsync(this IMqttServer server, string clientId, params TopicFilter[] topicFilters)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return server.SubscribeAsync(clientId, topicFilters);
}

public static Task SubscribeAsync(this IMqttServer server, string clientId, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topic == null) throw new ArgumentNullException(nameof(topic));

return server.SubscribeAsync(clientId, new TopicFilterBuilder().WithTopic(topic).WithQualityOfServiceLevel(qualityOfServiceLevel).Build());
}

public static Task SubscribeAsync(this IMqttServer server, string clientId, string topic)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topic == null) throw new ArgumentNullException(nameof(topic));

return server.SubscribeAsync(clientId, new TopicFilterBuilder().WithTopic(topic).Build());
}

public static Task UnsubscribeAsync(this IMqttServer server, string clientId, params string[] topicFilters)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return server.UnsubscribeAsync(clientId, topicFilters);
}
}
}

+ 140
- 60
Tests/MQTTnet.Core.Tests/MqttServerTests.cs View File

@@ -17,7 +17,7 @@ using MQTTnet.Server;
namespace MQTTnet.Tests namespace MQTTnet.Tests
{ {
[TestClass] [TestClass]
public class MqttServerTests
public partial class MqttServerTests
{ {
[TestMethod] [TestMethod]
public void MqttServer_PublishSimple_AtMostOnce() public void MqttServer_PublishSimple_AtMostOnce()
@@ -263,6 +263,91 @@ namespace MQTTnet.Tests
} }
} }


[TestMethod]
public async Task MqttServer_No_Messages_If_No_Subscription()
{
var server = new MqttFactory().CreateMqttServer();
try
{
await server.StartAsync(new MqttServerOptions());

var client = new MqttFactory().CreateMqttClient();
var receivedMessages = new List<MqttApplicationMessage>();

var options = new MqttClientOptionsBuilder()
.WithTcpServer("localhost").Build();

client.Connected += async (s, e) =>
{
await client.PublishAsync("Connected");
};

client.ApplicationMessageReceived += (s, e) =>
{
lock (receivedMessages)
{
receivedMessages.Add(e.ApplicationMessage);
}
};

await client.ConnectAsync(options);

await Task.Delay(500);

await client.PublishAsync("Hello");

await Task.Delay(500);
Assert.AreEqual(0, receivedMessages.Count);
}
finally
{
await server.StopAsync();
}
}

[TestMethod]
public async Task MqttServer_Set_Subscription_At_Server()
{
var server = new MqttFactory().CreateMqttServer();
try
{
await server.StartAsync(new MqttServerOptions());
server.ClientConnected += async (s, e) =>
{
await server.SubscribeAsync(e.ClientId, "topic1");
};

var client = new MqttFactory().CreateMqttClient();
var receivedMessages = new List<MqttApplicationMessage>();

var options = new MqttClientOptionsBuilder()
.WithTcpServer("localhost").Build();

client.ApplicationMessageReceived += (s, e) =>
{
lock (receivedMessages)
{
receivedMessages.Add(e.ApplicationMessage);
}
};

await client.ConnectAsync(options);

await Task.Delay(500);

await client.PublishAsync("Hello");

await Task.Delay(500);

Assert.AreEqual(0, receivedMessages.Count);
}
finally
{
await server.StopAsync();
}
}

private static async Task Publish(IMqttClient c1, MqttApplicationMessage message) private static async Task Publish(IMqttClient c1, MqttApplicationMessage message)
{ {
for (int i = 0; i < 1000; i++) for (int i = 0; i < 1000; i++)
@@ -302,40 +387,29 @@ namespace MQTTnet.Tests
[TestMethod] [TestMethod]
public async Task MqttServer_Handle_Clean_Disconnect() public async Task MqttServer_Handle_Clean_Disconnect()
{ {
var s = new MqttFactory().CreateMqttServer();
try
using (var testSetup = new TestSetup())
{ {
var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder());

var clientConnectedCalled = 0; var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0; var clientDisconnectedCalled = 0;


s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();
server.ClientConnected += (_, __) => Interlocked.Increment(ref clientConnectedCalled);
server.ClientDisconnected += (_, __) => Interlocked.Increment(ref clientDisconnectedCalled);
var c1 = await testSetup.ConnectClient(new MqttClientOptionsBuilder());


await s.StartAsync(new MqttServerOptions());
Assert.AreEqual(1, clientConnectedCalled);
Assert.AreEqual(0, clientDisconnectedCalled);


var c1 = new MqttFactory().CreateMqttClient();

await c1.ConnectAsync(clientOptions);

await Task.Delay(100);
await Task.Delay(500);


await c1.DisconnectAsync(); await c1.DisconnectAsync();


await Task.Delay(100);

await s.StopAsync();

await Task.Delay(100);
await Task.Delay(500);


Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
}
finally
{
await s.StopAsync();
Assert.AreEqual(1, clientConnectedCalled);
Assert.AreEqual(1, clientDisconnectedCalled);
} }
} }


@@ -385,7 +459,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_LotsOfRetainedMessages()
public async Task MqttServer_Lots_Of_Retained_Messages()
{ {
const int ClientCount = 100; const int ClientCount = 100;


@@ -745,58 +819,64 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_SameClientIdConnectDisconnectEventOrder()
public async Task MqttServer_Same_Client_Id_Connect_Disconnect_Event_Order()
{ {
var s = new MqttFactory().CreateMqttServer();
using (var testSetup = new TestSetup())
{
var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder());


var events = new List<string>();
var events = new List<string>();


s.ClientConnected += (_, __) =>
{
lock (events)
server.ClientConnected += (_, __) =>
{ {
events.Add("c");
}
};
lock (events)
{
events.Add("c");
}
};


s.ClientDisconnected += (_, __) =>
{
lock (events)
server.ClientDisconnected += (_, __) =>
{ {
events.Add("d");
}
};

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.WithClientId("same_id")
.Build();
lock (events)
{
events.Add("d");
}
};


await s.StartAsync(new MqttServerOptions());
var clientOptions = new MqttClientOptionsBuilder()
.WithClientId("same_id");


var c1 = new MqttFactory().CreateMqttClient();
var c2 = new MqttFactory().CreateMqttClient();
// c
var c1 = await testSetup.ConnectClient(clientOptions);
await Task.Delay(500);


await c1.ConnectAsync(clientOptions);
var flow = string.Join(string.Empty, events);
Assert.AreEqual("c", flow);


await Task.Delay(250);
// dc
var c2 = await testSetup.ConnectClient(clientOptions);


await c2.ConnectAsync(clientOptions);
await Task.Delay(500);


await Task.Delay(250);
flow = string.Join(string.Empty, events);
Assert.AreEqual("cdc", flow);


await c1.DisconnectAsync();
// nothing
await c1.DisconnectAsync();


await Task.Delay(250);
await Task.Delay(500);


await c2.DisconnectAsync();
// d
await c2.DisconnectAsync();


await Task.Delay(250);
await Task.Delay(500);


await s.StopAsync();
await server.StopAsync();


var flow = string.Join(string.Empty, events);
Assert.AreEqual("cdcd", flow);
flow = string.Join(string.Empty, events);
Assert.AreEqual("cdcd", flow);
}
} }


[TestMethod] [TestMethod]


+ 92
- 0
Tests/MQTTnet.Core.Tests/TestSetup.cs View File

@@ -0,0 +1,92 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Client;
using MQTTnet.Client.Options;
using MQTTnet.Diagnostics;
using MQTTnet.Server;

namespace MQTTnet.Tests
{
public class TestSetup : IDisposable
{
private readonly MqttFactory _mqttFactory = new MqttFactory();
private readonly List<IMqttClient> _clients = new List<IMqttClient>();
private readonly IMqttNetLogger _serverLogger = new MqttNetLogger("server");
private readonly IMqttNetLogger _clientLogger = new MqttNetLogger("client");

private IMqttServer _server;

private long _serverErrorsCount;
private long _clientErrorsCount;

public TestSetup()
{
_serverLogger.LogMessagePublished += (s, e) =>
{
if (e.TraceMessage.Level == MqttNetLogLevel.Error)
{
Interlocked.Increment(ref _serverErrorsCount);
}
};

_clientLogger.LogMessagePublished += (s, e) =>
{
if (e.TraceMessage.Level == MqttNetLogLevel.Error)
{
Interlocked.Increment(ref _clientErrorsCount);
}
};
}

public async Task<IMqttServer> StartServerAsync(MqttServerOptionsBuilder options)
{
if (_server != null)
{
throw new InvalidOperationException("Server already started.");
}

_server = _mqttFactory.CreateMqttServer(_serverLogger);
await _server.StartAsync(options.WithDefaultEndpointPort(1888).Build());

return _server;
}

public async Task<IMqttClient> ConnectClient(MqttClientOptionsBuilder options)
{
var client = _mqttFactory.CreateMqttClient(_clientLogger);
_clients.Add(client);

await client.ConnectAsync(options.WithTcpServer("localhost", 1888).Build());

return client;
}

public void ThrowIfLogErrors()
{
if (_serverErrorsCount > 0)
{
throw new Exception($"Server had {_serverErrorsCount} errors.");
}

if (_clientErrorsCount > 0)
{
throw new Exception($"Client(s) had {_clientErrorsCount} errors.");
}
}

public void Dispose()
{
ThrowIfLogErrors();

foreach (var mqttClient in _clients)
{
mqttClient.DisconnectAsync().GetAwaiter().GetResult();
mqttClient.Dispose();
}

_server.StopAsync().GetAwaiter().GetResult();
}
}
}

Loading…
Cancel
Save