Browse Source

Refactor async/await and ConcurrentDictionary usage.

release/3.x.x
Christian Kratky 6 years ago
parent
commit
0322660561
20 changed files with 406 additions and 242 deletions
  1. +4
    -4
      README.md
  2. +2
    -2
      Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs
  3. +4
    -9
      Source/MQTTnet/Client/MqttClient.cs
  4. +16
    -17
      Source/MQTTnet/Internal/AsyncAutoResetEvent.cs
  5. +1
    -1
      Source/MQTTnet/Internal/AsyncLock.cs
  6. +4
    -5
      Source/MQTTnet/Serializer/MqttPacketReader.cs
  7. +23
    -0
      Source/MQTTnet/Server/IMqttClientSession.cs
  8. +12
    -14
      Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs
  9. +38
    -20
      Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs
  10. +84
    -52
      Source/MQTTnet/Server/MqttClientSession.cs
  11. +7
    -7
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  12. +27
    -18
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  13. +6
    -4
      Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs
  14. +40
    -25
      Source/MQTTnet/Server/MqttRetainedMessagesManager.cs
  15. +6
    -4
      Source/MQTTnet/Server/MqttServer.cs
  16. +62
    -20
      Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs
  17. +6
    -36
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs
  18. +42
    -1
      Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs
  19. +13
    -1
      Tests/MQTTnet.TestApp.NetCore/Program.cs
  20. +9
    -2
      Tests/MQTTnet.TestApp.NetCore/ServerTest.cs

+ 4
- 4
README.md View File

@@ -18,7 +18,7 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov
* TLS 1.2 support for client and server (but not UWP servers) * TLS 1.2 support for client and server (but not UWP servers)
* Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS) * Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS)
* Lightweight (only the low level implementation of MQTT, no overhead) * Lightweight (only the low level implementation of MQTT, no overhead)
* Performance optimized (processing ~60.000 messages / second)*
* Performance optimized (processing ~70.000 messages / second)*
* Interfaces included for mocking and testing * Interfaces included for mocking and testing
* Access to internal trace messages * Access to internal trace messages
* Unit tested (~90 tests) * Unit tested (~90 tests)
@@ -50,14 +50,15 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov
* .NET Standard 1.3+ * .NET Standard 1.3+
* .NET Core 1.1+ * .NET Core 1.1+
* .NET Core App 1.1+ * .NET Core App 1.1+
* Universal Windows Platform (UWP) 10.0.10240+ (x86, x64, ARM, AnyCPU, Windows 10 IoT Core)
* .NET Framework 4.5.2+ (x86, x64, AnyCPU) * .NET Framework 4.5.2+ (x86, x64, AnyCPU)
* Mono 5.2+ * Mono 5.2+
* Universal Windows Platform (UWP) 10.0.10240+ (x86, x64, ARM, AnyCPU, Windows 10 IoT Core)
* Xamarin.Android 7.5+ * Xamarin.Android 7.5+
* Xamarin.iOS 10.14+ * Xamarin.iOS 10.14+


## Supported MQTT versions ## Supported MQTT versions


* 5.0.0 (planned)
* 3.1.1 * 3.1.1
* 3.1.0 * 3.1.0


@@ -79,8 +80,7 @@ This library is used in the following projects:


* MQTT Client Rx (Wrapper for Reactive Extensions, <https://github.com/1iveowl/MQTTClient.rx>) * MQTT Client Rx (Wrapper for Reactive Extensions, <https://github.com/1iveowl/MQTTClient.rx>)
* MQTT Tester (MQTT client test app for [Android](https://play.google.com/store/apps/details?id=com.liveowl.mqtttester) and [iOS](https://itunes.apple.com/us/app/mqtt-tester/id1278621826?mt=8)) * MQTT Tester (MQTT client test app for [Android](https://play.google.com/store/apps/details?id=com.liveowl.mqtttester) and [iOS](https://itunes.apple.com/us/app/mqtt-tester/id1278621826?mt=8))
* Wirehome (Open Source Home Automation system for .NET, <https://github.com/chkr1011/Wirehome>)

* HA4IoT (Open Source Home Automation system for .NET, <https://github.com/chkr1011/HA4IoT>)


If you use this library and want to see your project here please let me know. If you use this library and want to see your project here please let me know.




+ 2
- 2
Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs View File

@@ -82,11 +82,11 @@ namespace MQTTnet.Extensions.Rpc
timeoutCts.Cancel(false); timeoutCts.Cancel(false);
return result; return result;
} }
catch (TaskCanceledException taskCanceledException)
catch (OperationCanceledException exception)
{ {
if (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) if (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
{ {
throw new MqttCommunicationTimedOutException(taskCanceledException);
throw new MqttCommunicationTimedOutException(exception);
} }
else else
{ {


+ 4
- 9
Source/MQTTnet/Client/MqttClient.cs View File

@@ -271,21 +271,16 @@ namespace MQTTnet.Client


private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken)
{ {
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}
cancellationToken.ThrowIfCancellationRequested();


_sendTracker.Restart(); _sendTracker.Restart();

return _adapter.SendPacketAsync(packet, cancellationToken); return _adapter.SendPacketAsync(packet, cancellationToken);
} }


private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket
{ {
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}
cancellationToken.ThrowIfCancellationRequested();


_sendTracker.Restart(); _sendTracker.Restart();


@@ -524,7 +519,7 @@ namespace MQTTnet.Client
{ {
await task.ConfigureAwait(false); await task.ConfigureAwait(false);
} }
catch (TaskCanceledException)
catch (OperationCanceledException)
{ {
} }
} }


+ 16
- 17
Source/MQTTnet/Internal/AsyncAutoResetEvent.cs View File

@@ -11,8 +11,10 @@ namespace MQTTnet.Internal
private readonly LinkedList<TaskCompletionSource<bool>> _waiters = new LinkedList<TaskCompletionSource<bool>>(); private readonly LinkedList<TaskCompletionSource<bool>> _waiters = new LinkedList<TaskCompletionSource<bool>>();
private bool _isSignaled; private bool _isSignaled;


public AsyncAutoResetEvent() : this(false)
{ }
public AsyncAutoResetEvent()
: this(false)
{
}


public AsyncAutoResetEvent(bool signaled) public AsyncAutoResetEvent(bool signaled)
{ {
@@ -58,27 +60,24 @@ namespace MQTTnet.Internal
} }


var winner = await Task.WhenAny(tcs.Task, Task.Delay(timeout, cancellationToken)).ConfigureAwait(false); var winner = await Task.WhenAny(tcs.Task, Task.Delay(timeout, cancellationToken)).ConfigureAwait(false);
if (winner == tcs.Task)
var taskWasSignaled = winner == tcs.Task;
if (taskWasSignaled)
{ {
// The task was signaled.
return true; return true;
} }
else

// We timed-out; remove our reference to the task.
// This is an O(n) operation since waiters is a LinkedList<T>.
lock (_waiters)
{ {
// We timed-out; remove our reference to the task.
// This is an O(n) operation since waiters is a LinkedList<T>.
lock (_waiters)
_waiters.Remove(tcs);
if (winner.Status == TaskStatus.Canceled)
{ {
_waiters.Remove(tcs);
if (winner.Status == TaskStatus.Canceled)
{
throw new OperationCanceledException(cancellationToken);
}
else
{
throw new TimeoutException();
}
throw new OperationCanceledException(cancellationToken);
} }

throw new TimeoutException();
} }
} }




+ 1
- 1
Source/MQTTnet/Internal/AsyncLock.cs View File

@@ -17,7 +17,7 @@ namespace MQTTnet.Internal


public Task<IDisposable> LockAsync(CancellationToken cancellationToken) public Task<IDisposable> LockAsync(CancellationToken cancellationToken)
{ {
Task wait = _semaphore.WaitAsync(cancellationToken);
var wait = _semaphore.WaitAsync(cancellationToken);
return wait.IsCompleted ? return wait.IsCompleted ?
_releaser : _releaser :
wait.ContinueWith((_, state) => (IDisposable)state, wait.ContinueWith((_, state) => (IDisposable)state,


+ 4
- 5
Source/MQTTnet/Serializer/MqttPacketReader.cs View File

@@ -29,11 +29,7 @@ namespace MQTTnet.Serializer
var bytesRead = await channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); var bytesRead = await channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false);
if (bytesRead <= 0) if (bytesRead <= 0)
{ {
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}

cancellationToken.ThrowIfCancellationRequested();
ExceptionHelper.ThrowGracefulSocketClose(); ExceptionHelper.ThrowGracefulSocketClose();
} }


@@ -59,6 +55,8 @@ namespace MQTTnet.Serializer


while ((encodedByte & 128) != 0) while ((encodedByte & 128) != 0)
{ {
cancellationToken.ThrowIfCancellationRequested();

// Here the async/await pattern is not used becuase the overhead of context switches // Here the async/await pattern is not used becuase the overhead of context switches
// is too big for reading 1 byte in a row. We expect that the remaining data was sent // is too big for reading 1 byte in a row. We expect that the remaining data was sent
// directly after the initial bytes. If the client disconnects just in this moment we // directly after the initial bytes. If the client disconnects just in this moment we
@@ -83,6 +81,7 @@ namespace MQTTnet.Serializer
var readCount = channel.ReadAsync(buffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); var readCount = channel.ReadAsync(buffer, 0, 1, cancellationToken).GetAwaiter().GetResult();
if (readCount <= 0) if (readCount <= 0)
{ {
cancellationToken.ThrowIfCancellationRequested();
ExceptionHelper.ThrowGracefulSocketClose(); ExceptionHelper.ThrowGracefulSocketClose();
} }




+ 23
- 0
Source/MQTTnet/Server/IMqttClientSession.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Packets;

namespace MQTTnet.Server
{
public interface IMqttClientSession : IDisposable
{
string ClientId { get; }
void FillStatus(MqttClientSessionStatus status);

void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket);
void ClearPendingApplicationMessages();
Task<bool> RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter);
void Stop(MqttClientDisconnectType disconnectType);

Task SubscribeAsync(IList<TopicFilter> topicFilters);
Task UnsubscribeAsync(IList<string> topicFilters);
}
}

+ 12
- 14
Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs View File

@@ -12,19 +12,17 @@ namespace MQTTnet.Server
private readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch(); private readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch();
private readonly Stopwatch _lastNonKeepAlivePacketReceivedTracker = new Stopwatch(); private readonly Stopwatch _lastNonKeepAlivePacketReceivedTracker = new Stopwatch();


private readonly IMqttClientSession _clientSession;
private readonly IMqttNetChildLogger _logger; private readonly IMqttNetChildLogger _logger;
private readonly string _clientId;
private readonly Action _callback;

private bool _isPaused; private bool _isPaused;
private Task _workerTask;

public MqttClientKeepAliveMonitor(string clientId, Action callback, IMqttNetChildLogger logger)
public MqttClientKeepAliveMonitor(IMqttClientSession clientSession, IMqttNetChildLogger logger)
{ {
if (logger == null) throw new ArgumentNullException(nameof(logger)); if (logger == null) throw new ArgumentNullException(nameof(logger));


_clientId = clientId;
_callback = callback;
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession));
_logger = logger.CreateChildLogger(nameof(MqttClientKeepAliveMonitor)); _logger = logger.CreateChildLogger(nameof(MqttClientKeepAliveMonitor));
} }


@@ -39,7 +37,7 @@ namespace MQTTnet.Server
return; return;
} }


_workerTask = Task.Run(() => RunAsync(keepAlivePeriod, cancellationToken), cancellationToken);
Task.Run(() => RunAsync(keepAlivePeriod, cancellationToken), cancellationToken);
} }


public void Pause() public void Pause()
@@ -74,9 +72,9 @@ namespace MQTTnet.Server
// Values described here: [MQTT-3.1.2-24]. // Values described here: [MQTT-3.1.2-24].
if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds > keepAlivePeriod * 1.5D) if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds > keepAlivePeriod * 1.5D)
{ {
_logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientId);
_callback();
_logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId);
_clientSession.Stop(MqttClientDisconnectType.NotClean);
return; return;
} }


@@ -88,11 +86,11 @@ namespace MQTTnet.Server
} }
catch (Exception exception) catch (Exception exception)
{ {
_logger.Error(exception, "Client '{0}': Unhandled exception while checking keep alive timeouts.", _clientId);
_logger.Error(exception, "Client '{0}': Unhandled exception while checking keep alive timeouts.", _clientSession.ClientId);
} }
finally finally
{ {
_logger.Verbose("Client {0}: Stopped checking keep alive timeout.", _clientId);
_logger.Verbose("Client {0}: Stopped checking keep alive timeout.", _clientSession.ClientId);
} }
} }
} }


+ 38
- 20
Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs View File

@@ -1,5 +1,5 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Adapter; using MQTTnet.Adapter;
@@ -13,13 +13,13 @@ namespace MQTTnet.Server
{ {
public class MqttClientPendingPacketsQueue : IDisposable public class MqttClientPendingPacketsQueue : IDisposable
{ {
private readonly Queue<MqttBasePacket> _queue = new Queue<MqttBasePacket>();
private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent();

private readonly IMqttServerOptions _options; private readonly IMqttServerOptions _options;
private readonly MqttClientSession _clientSession; private readonly MqttClientSession _clientSession;
private readonly IMqttNetChildLogger _logger; private readonly IMqttNetChildLogger _logger;


private ConcurrentQueue<MqttBasePacket> _queue = new ConcurrentQueue<MqttBasePacket>();

public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
{ {
if (logger == null) throw new ArgumentNullException(nameof(logger)); if (logger == null) throw new ArgumentNullException(nameof(logger));
@@ -29,7 +29,16 @@ namespace MQTTnet.Server
_logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue)); _logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue));
} }


public int Count => _queue.Count;
public int Count
{
get
{
lock (_queue)
{
return _queue.Count;
}
}
}


public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken) public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken)
{ {
@@ -42,25 +51,29 @@ namespace MQTTnet.Server


Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken); Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken);
} }
public void Enqueue(MqttBasePacket packet) public void Enqueue(MqttBasePacket packet)
{ {
if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packet == null) throw new ArgumentNullException(nameof(packet));


if (_queue.Count >= _options.MaxPendingMessagesPerClient)
lock (_queue)
{ {
if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage)
if (_queue.Count >= _options.MaxPendingMessagesPerClient)
{ {
return;
}
if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage)
{
return;
}


if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage)
{
_queue.TryDequeue(out _);
if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage)
{
_queue.Dequeue();
}
} }
_queue.Enqueue(packet);
} }


_queue.Enqueue(packet);
_queueAutoResetEvent.Set(); _queueAutoResetEvent.Set();


_logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId);
@@ -68,13 +81,14 @@ namespace MQTTnet.Server


public void Clear() public void Clear()
{ {
var newQueue = new ConcurrentQueue<MqttBasePacket>();
Interlocked.Exchange(ref _queue, newQueue);
lock (_queue)
{
_queue.Clear();
}
} }


public void Dispose() public void Dispose()
{ {
} }


private async Task SendQueuedPacketsAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken) private async Task SendQueuedPacketsAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken)
@@ -100,13 +114,17 @@ namespace MQTTnet.Server
MqttBasePacket packet = null; MqttBasePacket packet = null;
try try
{ {
if (_queue.IsEmpty)
lock (_queue)
{ {
await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false);
if (_queue.Count > 0)
{
packet = _queue.Dequeue();
}
} }


if (!_queue.TryDequeue(out packet))
if (packet == null)
{ {
await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false);
return; return;
} }


@@ -115,7 +133,7 @@ namespace MQTTnet.Server
return; return;
} }


await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false);
adapter.SendPacketAsync(packet, cancellationToken).GetAwaiter().GetResult();


_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId);
} }


+ 84
- 52
Source/MQTTnet/Server/MqttClientSession.cs View File

@@ -12,7 +12,7 @@ using MQTTnet.Protocol;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
public class MqttClientSession : IDisposable
public class MqttClientSession : IMqttClientSession
{ {
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();


@@ -47,7 +47,7 @@ namespace MQTTnet.Server


_logger = logger.CreateChildLogger(nameof(MqttClientSession)); _logger = logger.CreateChildLogger(nameof(MqttClientSession));


_keepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, () => Stop(MqttClientDisconnectType.NotClean), _logger);
_keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger);
_subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server);
_pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger); _pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger);
} }
@@ -89,7 +89,7 @@ namespace MQTTnet.Server
if (packet != null) if (packet != null)
{ {
_keepAliveMonitor.PacketReceived(packet); _keepAliveMonitor.PacketReceived(packet);
await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false);
ProcessReceivedPacket(adapter, packet, _cancellationTokenSource.Token);
} }
} }
} }
@@ -102,7 +102,7 @@ namespace MQTTnet.Server
{ {
if (exception is MqttCommunicationClosedGracefullyException) if (exception is MqttCommunicationClosedGracefullyException)
{ {
_logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId); ;
_logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId);
} }
else else
{ {
@@ -113,7 +113,7 @@ namespace MQTTnet.Server
{ {
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId);
} }
Stop(MqttClientDisconnectType.NotClean); Stop(MqttClientDisconnectType.NotClean);
} }
finally finally
@@ -123,7 +123,7 @@ namespace MQTTnet.Server
_adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; _adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
_adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; _adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
} }
_adapter = null; _adapter = null;


_cancellationTokenSource?.Dispose(); _cancellationTokenSource?.Dispose();
@@ -149,7 +149,7 @@ namespace MQTTnet.Server


if (_willMessage != null && !_wasCleanDisconnect) if (_willMessage != null && !_wasCleanDisconnect)
{ {
_sessionsManager.EnqueueApplicationMessage(this, _willMessage);
_sessionsManager.EnqueueApplicationMessage(this, _willMessage.ToPublishPacket());
} }


_willMessage = null; _willMessage = null;
@@ -160,18 +160,24 @@ namespace MQTTnet.Server
} }
} }


public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket)
{ {
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));
if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket));


var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(applicationMessage);
var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(publishPacket.Topic, publishPacket.QualityOfServiceLevel);
if (!checkSubscriptionsResult.IsSubscribed) if (!checkSubscriptionsResult.IsSubscribed)
{ {
return; return;
} }


var publishPacket = applicationMessage.ToPublishPacket();
publishPacket.QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel;
publishPacket = new MqttPublishPacket
{
Topic = publishPacket.Topic,
Payload = publishPacket.Payload,
QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel,
Retain = false,
Dup = false
};


if (publishPacket.QualityOfServiceLevel > 0) if (publishPacket.QualityOfServiceLevel > 0)
{ {
@@ -184,15 +190,19 @@ namespace MQTTnet.Server
senderClientSession?.ClientId, senderClientSession?.ClientId,
ClientId, ClientId,
publishPacket.ToApplicationMessage()); publishPacket.ToApplicationMessage());
_options.ClientMessageQueueInterceptor?.Invoke(context); _options.ClientMessageQueueInterceptor?.Invoke(context);


if (!context.AcceptEnqueue || context.ApplicationMessage == null) if (!context.AcceptEnqueue || context.ApplicationMessage == null)
{ {
return; return;
} }

publishPacket.Topic = context.ApplicationMessage.Topic;
publishPacket.Payload = context.ApplicationMessage.Payload;
publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel;
} }
_pendingPacketsQueue.Enqueue(publishPacket); _pendingPacketsQueue.Enqueue(publishPacket);
} }


@@ -233,21 +243,29 @@ namespace MQTTnet.Server
_cancellationTokenSource?.Dispose(); _cancellationTokenSource?.Dispose();
} }


private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
{ {
if (packet is MqttPublishPacket publishPacket) if (packet is MqttPublishPacket publishPacket)
{ {
return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken);
HandleIncomingPublishPacket(adapter, publishPacket, cancellationToken);
return;
} }


if (packet is MqttPingReqPacket) if (packet is MqttPingReqPacket)
{ {
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken).GetAwaiter().GetResult();
return;
} }


if (packet is MqttPubRelPacket pubRelPacket) if (packet is MqttPubRelPacket pubRelPacket)
{ {
return HandleIncomingPubRelPacketAsync(adapter, pubRelPacket, cancellationToken);
var responsePacket = new MqttPubCompPacket
{
PacketIdentifier = pubRelPacket.PacketIdentifier
};

adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult();
return;
} }


if (packet is MqttPubRecPacket pubRecPacket) if (packet is MqttPubRecPacket pubRecPacket)
@@ -257,40 +275,41 @@ namespace MQTTnet.Server
PacketIdentifier = pubRecPacket.PacketIdentifier PacketIdentifier = pubRecPacket.PacketIdentifier
}; };


return adapter.SendPacketAsync(responsePacket, cancellationToken);
adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult();
return;
} }


if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
{ {
// Discard message.
return Task.FromResult(0);
return;
} }


if (packet is MqttSubscribePacket subscribePacket) if (packet is MqttSubscribePacket subscribePacket)
{ {
return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken);
HandleIncomingSubscribePacket(adapter, subscribePacket, cancellationToken);
return;
} }


if (packet is MqttUnsubscribePacket unsubscribePacket) if (packet is MqttUnsubscribePacket unsubscribePacket)
{ {
return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken);
HandleIncomingUnsubscribePacket(adapter, unsubscribePacket, cancellationToken);
return;
} }


if (packet is MqttDisconnectPacket) if (packet is MqttDisconnectPacket)
{ {
Stop(MqttClientDisconnectType.Clean); Stop(MqttClientDisconnectType.Clean);
return Task.FromResult(0);
return;
} }


if (packet is MqttConnectPacket) if (packet is MqttConnectPacket)
{ {
Stop(MqttClientDisconnectType.NotClean); Stop(MqttClientDisconnectType.NotClean);
return Task.FromResult(0);
return;
} }


_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); _logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet);
Stop(MqttClientDisconnectType.NotClean); Stop(MqttClientDisconnectType.NotClean);
return Task.FromResult(0);
} }


private void EnqueueSubscribedRetainedMessages(ICollection<TopicFilter> topicFilters) private void EnqueueSubscribedRetainedMessages(ICollection<TopicFilter> topicFilters)
@@ -298,14 +317,14 @@ namespace MQTTnet.Server
var retainedMessages = _retainedMessagesManager.GetSubscribedMessages(topicFilters); var retainedMessages = _retainedMessagesManager.GetSubscribedMessages(topicFilters);
foreach (var applicationMessage in retainedMessages) foreach (var applicationMessage in retainedMessages)
{ {
EnqueueApplicationMessage(null, applicationMessage);
EnqueueApplicationMessage(null, applicationMessage.ToPublishPacket());
} }
} }


private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
private void HandleIncomingSubscribePacket(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
{ {
var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket);
await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false);
adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).GetAwaiter().GetResult();


if (subscribeResult.CloseConnection) if (subscribeResult.CloseConnection)
{ {
@@ -316,30 +335,30 @@ namespace MQTTnet.Server
EnqueueSubscribedRetainedMessages(subscribePacket.TopicFilters); EnqueueSubscribedRetainedMessages(subscribePacket.TopicFilters);
} }


private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
private void HandleIncomingUnsubscribePacket(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
{ {
var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket);
return adapter.SendPacketAsync(unsubscribeResult, cancellationToken);
adapter.SendPacketAsync(unsubscribeResult, cancellationToken).GetAwaiter().GetResult();
} }


private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
private void HandleIncomingPublishPacket(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{ {
var applicationMessage = publishPacket.ToApplicationMessage();

switch (applicationMessage.QualityOfServiceLevel)
switch (publishPacket.QualityOfServiceLevel)
{ {
case MqttQualityOfServiceLevel.AtMostOnce: case MqttQualityOfServiceLevel.AtMostOnce:
{ {
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage);
return Task.FromResult(0);
HandleIncomingPublishPacketWithQoS0(publishPacket);
break;
} }
case MqttQualityOfServiceLevel.AtLeastOnce: case MqttQualityOfServiceLevel.AtLeastOnce:
{ {
return HandleIncomingPublishPacketWithQoS1(adapter, applicationMessage, publishPacket, cancellationToken);
HandleIncomingPublishPacketWithQoS1(adapter, publishPacket, cancellationToken);
break;
} }
case MqttQualityOfServiceLevel.ExactlyOnce: case MqttQualityOfServiceLevel.ExactlyOnce:
{ {
return HandleIncomingPublishPacketWithQoS2(adapter, applicationMessage, publishPacket, cancellationToken);
HandleIncomingPublishPacketWithQoS2(adapter, publishPacket, cancellationToken);
break;
} }
default: default:
{ {
@@ -348,27 +367,40 @@ namespace MQTTnet.Server
} }
} }


private Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket)
{ {
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage);

var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier };
return adapter.SendPacketAsync(response, cancellationToken);
_sessionsManager.EnqueueApplicationMessage(this, publishPacket);
} }


private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
private void HandleIncomingPublishPacketWithQoS1(
IMqttChannelAdapter adapter,
MqttPublishPacket publishPacket,
CancellationToken cancellationToken)
{ {
// QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery)
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage);
_sessionsManager.EnqueueApplicationMessage(this, publishPacket);

var response = new MqttPubAckPacket
{
PacketIdentifier = publishPacket.PacketIdentifier
};


var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier };
return adapter.SendPacketAsync(response, cancellationToken);
adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult();
} }


private static Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken)
private void HandleIncomingPublishPacketWithQoS2(
IMqttChannelAdapter adapter,
MqttPublishPacket publishPacket,
CancellationToken cancellationToken)
{ {
var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier };
return adapter.SendPacketAsync(response, cancellationToken);
// QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery)
_sessionsManager.EnqueueApplicationMessage(this, publishPacket);

var response = new MqttPubRecPacket
{
PacketIdentifier = publishPacket.PacketIdentifier
};

adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult();
} }


private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) private void OnAdapterReadingPacketCompleted(object sender, EventArgs e)


+ 7
- 7
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -6,6 +6,7 @@ using System.Threading.Tasks;
using MQTTnet.Adapter; using MQTTnet.Adapter;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Exceptions; using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets; using MQTTnet.Packets;
using MQTTnet.Protocol; using MQTTnet.Protocol;


@@ -41,7 +42,7 @@ namespace MQTTnet.Server
Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
} }


public Task StopAsync()
public void Stop()
{ {
foreach (var session in _sessions) foreach (var session in _sessions)
{ {
@@ -49,7 +50,6 @@ namespace MQTTnet.Server
} }


_sessions.Clear(); _sessions.Clear();
return Task.FromResult(0);
} }


public Task StartSession(IMqttChannelAdapter clientAdapter) public Task StartSession(IMqttChannelAdapter clientAdapter)
@@ -71,11 +71,11 @@ namespace MQTTnet.Server
return Task.FromResult((IList<IMqttClientSessionStatus>)result); return Task.FromResult((IList<IMqttClientSessionStatus>)result);
} }


public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket)
{ {
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));
if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket));


_messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken);
_messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, publishPacket), _cancellationToken);
} }


public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters) public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
@@ -118,7 +118,7 @@ namespace MQTTnet.Server
{ {
var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken); var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken);
var sender = enqueuedApplicationMessage.Sender; var sender = enqueuedApplicationMessage.Sender;
var applicationMessage = enqueuedApplicationMessage.ApplicationMessage;
var applicationMessage = enqueuedApplicationMessage.PublishPacket.ToApplicationMessage();


var interceptorContext = InterceptApplicationMessage(sender, applicationMessage); var interceptorContext = InterceptApplicationMessage(sender, applicationMessage);
if (interceptorContext != null) if (interceptorContext != null)
@@ -145,7 +145,7 @@ namespace MQTTnet.Server


foreach (var clientSession in _sessions.Values) foreach (var clientSession in _sessions.Values)
{ {
clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage);
clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage.ToPublishPacket());
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)


+ 27
- 18
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs View File

@@ -1,5 +1,4 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using MQTTnet.Packets; using MQTTnet.Packets;
@@ -9,7 +8,7 @@ namespace MQTTnet.Server
{ {
public class MqttClientSubscriptionsManager public class MqttClientSubscriptionsManager
{ {
private readonly ConcurrentDictionary<string, MqttQualityOfServiceLevel> _subscriptions = new ConcurrentDictionary<string, MqttQualityOfServiceLevel>();
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly IMqttServerOptions _options; private readonly IMqttServerOptions _options;
private readonly MqttServer _server; private readonly MqttServer _server;
private readonly string _clientId; private readonly string _clientId;
@@ -54,7 +53,11 @@ namespace MQTTnet.Server


if (interceptorContext.AcceptSubscription) if (interceptorContext.AcceptSubscription)
{ {
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
lock (_subscriptions)
{
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
}

_server.OnClientSubscribedTopic(_clientId, topicFilter); _server.OnClientSubscribedTopic(_clientId, topicFilter);
} }
} }
@@ -66,10 +69,14 @@ namespace MQTTnet.Server
{ {
if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket));


foreach (var topicFilter in unsubscribePacket.TopicFilters)
lock (_subscriptions)
{ {
_subscriptions.TryRemove(topicFilter, out _);
_server.OnClientUnsubscribedTopic(_clientId, topicFilter);
foreach (var topicFilter in unsubscribePacket.TopicFilters)
{
_subscriptions.Remove(topicFilter);

_server.OnClientUnsubscribedTopic(_clientId, topicFilter);
}
} }


return new MqttUnsubAckPacket return new MqttUnsubAckPacket
@@ -78,19 +85,21 @@ namespace MQTTnet.Server
}; };
} }


public CheckSubscriptionsResult CheckSubscriptions(MqttApplicationMessage applicationMessage)
public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel)
{ {
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

var qosLevels = new HashSet<MqttQualityOfServiceLevel>(); var qosLevels = new HashSet<MqttQualityOfServiceLevel>();
foreach (var subscription in _subscriptions)

lock (_subscriptions)
{ {
if (!MqttTopicFilterComparer.IsMatch(applicationMessage.Topic, subscription.Key))
foreach (var subscription in _subscriptions)
{ {
continue;
}
if (!MqttTopicFilterComparer.IsMatch(topic, subscription.Key))
{
continue;
}


qosLevels.Add(subscription.Value);
qosLevels.Add(subscription.Value);
}
} }


if (qosLevels.Count == 0) if (qosLevels.Count == 0)
@@ -101,7 +110,7 @@ namespace MQTTnet.Server
}; };
} }


return CreateSubscriptionResult(applicationMessage, qosLevels);
return CreateSubscriptionResult(qosLevel, qosLevels);
} }


private static MqttSubscribeReturnCode ConvertToMaximumQoS(MqttQualityOfServiceLevel qualityOfServiceLevel) private static MqttSubscribeReturnCode ConvertToMaximumQoS(MqttQualityOfServiceLevel qualityOfServiceLevel)
@@ -122,12 +131,12 @@ namespace MQTTnet.Server
return interceptorContext; return interceptorContext;
} }


private static CheckSubscriptionsResult CreateSubscriptionResult(MqttApplicationMessage applicationMessage, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels)
private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels)
{ {
MqttQualityOfServiceLevel effectiveQoS; MqttQualityOfServiceLevel effectiveQoS;
if (subscribedQoSLevels.Contains(applicationMessage.QualityOfServiceLevel))
if (subscribedQoSLevels.Contains(qosLevel))
{ {
effectiveQoS = applicationMessage.QualityOfServiceLevel;
effectiveQoS = qosLevel;
} }
else if (subscribedQoSLevels.Count == 1) else if (subscribedQoSLevels.Count == 1)
{ {


+ 6
- 4
Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs View File

@@ -1,15 +1,17 @@
namespace MQTTnet.Server
using MQTTnet.Packets;

namespace MQTTnet.Server
{ {
public class MqttEnqueuedApplicationMessage public class MqttEnqueuedApplicationMessage
{ {
public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage)
public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttPublishPacket publishPacket)
{ {
Sender = sender; Sender = sender;
ApplicationMessage = applicationMessage;
PublishPacket = publishPacket;
} }


public MqttClientSession Sender { get; } public MqttClientSession Sender { get; }


public MqttApplicationMessage ApplicationMessage { get; }
public MqttPublishPacket PublishPacket { get; }
} }
} }

+ 40
- 25
Source/MQTTnet/Server/MqttRetainedMessagesManager.cs View File

@@ -1,5 +1,4 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -9,7 +8,8 @@ namespace MQTTnet.Server
{ {
public class MqttRetainedMessagesManager public class MqttRetainedMessagesManager
{ {
private readonly ConcurrentDictionary<string, MqttApplicationMessage> _messages = new ConcurrentDictionary<string, MqttApplicationMessage>();
private readonly Dictionary<string, MqttApplicationMessage> _messages = new Dictionary<string, MqttApplicationMessage>();

private readonly IMqttNetChildLogger _logger; private readonly IMqttNetChildLogger _logger;
private readonly IMqttServerOptions _options; private readonly IMqttServerOptions _options;


@@ -31,10 +31,13 @@ namespace MQTTnet.Server
{ {
var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false);


_messages.Clear();
foreach (var retainedMessage in retainedMessages)
lock (_messages)
{ {
_messages[retainedMessage.Topic] = retainedMessage;
_messages.Clear();
foreach (var retainedMessage in retainedMessages)
{
_messages[retainedMessage.Topic] = retainedMessage;
}
} }
} }
catch (Exception exception) catch (Exception exception)
@@ -61,17 +64,20 @@ namespace MQTTnet.Server
{ {
var retainedMessages = new List<MqttApplicationMessage>(); var retainedMessages = new List<MqttApplicationMessage>();


foreach (var retainedMessage in _messages.Values)
lock (_messages)
{ {
foreach (var topicFilter in topicFilters)
foreach (var retainedMessage in _messages.Values)
{ {
if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic))
foreach (var topicFilter in topicFilters)
{ {
continue;
}
if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic))
{
continue;
}


retainedMessages.Add(retainedMessage);
break;
retainedMessages.Add(retainedMessage);
break;
}
} }
} }


@@ -82,28 +88,31 @@ namespace MQTTnet.Server
{ {
var saveIsRequired = false; var saveIsRequired = false;


if (applicationMessage.Payload?.Length == 0)
{
saveIsRequired = _messages.TryRemove(applicationMessage.Topic, out _);
_logger.Info("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic);
}
else
lock (_messages)
{ {
if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage))
if (applicationMessage.Payload?.Length == 0)
{ {
_messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true;
saveIsRequired = _messages.Remove(applicationMessage.Topic);
_logger.Info("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic);
} }
else else
{ {
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0]))
if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage))
{ {
_messages[applicationMessage.Topic] = applicationMessage; _messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true; saveIsRequired = true;
} }
}
else
{
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0]))
{
_messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true;
}
}


_logger.Info("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic);
_logger.Info("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic);
}
} }


if (!saveIsRequired) if (!saveIsRequired)
@@ -113,7 +122,13 @@ namespace MQTTnet.Server


if (saveIsRequired && _options.Storage != null) if (saveIsRequired && _options.Storage != null)
{ {
await _options.Storage.SaveRetainedMessagesAsync(_messages.Values.ToList()).ConfigureAwait(false);
List<MqttApplicationMessage> messages;
lock (_messages)
{
messages = _messages.Values.ToList();
}

await _options.Storage.SaveRetainedMessagesAsync(messages).ConfigureAwait(false);
} }
} }
} }


+ 6
- 4
Source/MQTTnet/Server/MqttServer.cs View File

@@ -5,6 +5,7 @@ using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Adapter; using MQTTnet.Adapter;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Internal;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
@@ -65,7 +66,7 @@ namespace MQTTnet.Server


if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started."); if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started.");


_clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage);
_clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage.ToPublishPacket());


return Task.FromResult(0); return Task.FromResult(0);
} }
@@ -104,22 +105,23 @@ namespace MQTTnet.Server
} }


_cancellationTokenSource.Cancel(false); _cancellationTokenSource.Cancel(false);
_cancellationTokenSource.Dispose();

foreach (var adapter in _adapters) foreach (var adapter in _adapters)
{ {
adapter.ClientAccepted -= OnClientAccepted; adapter.ClientAccepted -= OnClientAccepted;
await adapter.StopAsync().ConfigureAwait(false); await adapter.StopAsync().ConfigureAwait(false);
} }


await _clientSessionsManager.StopAsync().ConfigureAwait(false);
_clientSessionsManager.Stop();


_logger.Info("Stopped."); _logger.Info("Stopped.");
Stopped?.Invoke(this, EventArgs.Empty); Stopped?.Invoke(this, EventArgs.Empty);
} }
finally finally
{ {
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null; _cancellationTokenSource = null;

_retainedMessagesManager = null; _retainedMessagesManager = null;
_clientSessionsManager = null; _clientSessionsManager = null;
} }


+ 62
- 20
Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs View File

@@ -1,5 +1,9 @@
using System.Threading;
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Packets; using MQTTnet.Packets;
using MQTTnet.Server; using MQTTnet.Server;
@@ -12,39 +16,31 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public void KeepAlive_Timeout() public void KeepAlive_Timeout()
{ {
var timeoutCalledCount = 0;
var clientSession = new TestClientSession();
var monitor = new MqttClientKeepAliveMonitor(clientSession, new MqttNetLogger().CreateChildLogger());


var monitor = new MqttClientKeepAliveMonitor(string.Empty, delegate
{
timeoutCalledCount++;
}, new MqttNetLogger().CreateChildLogger(""));

Assert.AreEqual(0, timeoutCalledCount);
Assert.AreEqual(0, clientSession.StopCalledCount);


monitor.Start(1, CancellationToken.None); monitor.Start(1, CancellationToken.None);


Assert.AreEqual(0, timeoutCalledCount);
Assert.AreEqual(0, clientSession.StopCalledCount);


Thread.Sleep(2000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. Thread.Sleep(2000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification.


Assert.AreEqual(1, timeoutCalledCount);
Assert.AreEqual(1, clientSession.StopCalledCount);
} }


[TestMethod] [TestMethod]
public void KeepAlive_NoTimeout() public void KeepAlive_NoTimeout()
{ {
var timeoutCalledCount = 0;

var monitor = new MqttClientKeepAliveMonitor(string.Empty, delegate
{
timeoutCalledCount++;
}, new MqttNetLogger().CreateChildLogger(""));
var clientSession = new TestClientSession();
var monitor = new MqttClientKeepAliveMonitor(clientSession, new MqttNetLogger().CreateChildLogger());


Assert.AreEqual(0, timeoutCalledCount);
Assert.AreEqual(0, clientSession.StopCalledCount);


monitor.Start(1, CancellationToken.None); monitor.Start(1, CancellationToken.None);


Assert.AreEqual(0, timeoutCalledCount);
Assert.AreEqual(0, clientSession.StopCalledCount);


// Simulate traffic. // Simulate traffic.
Thread.Sleep(1000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. Thread.Sleep(1000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification.
@@ -53,11 +49,57 @@ namespace MQTTnet.Core.Tests
monitor.PacketReceived(new MqttPublishPacket()); monitor.PacketReceived(new MqttPublishPacket());
Thread.Sleep(1000); Thread.Sleep(1000);


Assert.AreEqual(0, timeoutCalledCount);
Assert.AreEqual(0, clientSession.StopCalledCount);


Thread.Sleep(2000); Thread.Sleep(2000);


Assert.AreEqual(1, timeoutCalledCount);
Assert.AreEqual(1, clientSession.StopCalledCount);
}

private class TestClientSession : IMqttClientSession
{
public string ClientId { get; }

public int StopCalledCount { get; set; }

public void FillStatus(MqttClientSessionStatus status)
{
throw new NotSupportedException();
}

public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket)
{
throw new NotSupportedException();
}

public void ClearPendingApplicationMessages()
{
throw new NotSupportedException();
}

public Task<bool> RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
{
throw new NotSupportedException();
}

public void Stop(MqttClientDisconnectType disconnectType)
{
StopCalledCount++;
}

public Task SubscribeAsync(IList<TopicFilter> topicFilters)
{
throw new NotSupportedException();
}

public Task UnsubscribeAsync(IList<string> topicFilters)
{
throw new NotSupportedException();
}

public void Dispose()
{
}
} }
} }
} }

+ 6
- 36
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs View File

@@ -20,13 +20,7 @@ namespace MQTTnet.Core.Tests


sm.Subscribe(sp); sm.Subscribe(sp);


var pp = new MqttApplicationMessage
{
Topic = "A/B/C",
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce
};

var result = sm.CheckSubscriptions(pp);
var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce);
Assert.IsTrue(result.IsSubscribed); Assert.IsTrue(result.IsSubscribed);
Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtMostOnce); Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtMostOnce);
} }
@@ -41,13 +35,7 @@ namespace MQTTnet.Core.Tests


sm.Subscribe(sp); sm.Subscribe(sp);


var pp = new MqttApplicationMessage
{
Topic = "A/B/C",
QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce
};

var result = sm.CheckSubscriptions(pp);
var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed); Assert.IsTrue(result.IsSubscribed);
Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtMostOnce); Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtMostOnce);
} }
@@ -63,13 +51,7 @@ namespace MQTTnet.Core.Tests


sm.Subscribe(sp); sm.Subscribe(sp);


var pp = new MqttApplicationMessage
{
Topic = "A/B/C",
QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce
};

var result = sm.CheckSubscriptions(pp);
var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed); Assert.IsTrue(result.IsSubscribed);
Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtLeastOnce); Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtLeastOnce);
} }
@@ -84,13 +66,7 @@ namespace MQTTnet.Core.Tests


sm.Subscribe(sp); sm.Subscribe(sp);


var pp = new MqttApplicationMessage
{
Topic = "A/B/X",
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce
};

Assert.IsFalse(sm.CheckSubscriptions(pp).IsSubscribed);
Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);
} }


[TestMethod] [TestMethod]
@@ -103,19 +79,13 @@ namespace MQTTnet.Core.Tests


sm.Subscribe(sp); sm.Subscribe(sp);


var pp = new MqttApplicationMessage
{
Topic = "A/B/C",
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce
};

Assert.IsTrue(sm.CheckSubscriptions(pp).IsSubscribed);
Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);


var up = new MqttUnsubscribePacket(); var up = new MqttUnsubscribePacket();
up.TopicFilters.Add("A/B/C"); up.TopicFilters.Add("A/B/C");
sm.Unsubscribe(up); sm.Unsubscribe(up);


Assert.IsFalse(sm.CheckSubscriptions(pp).IsSubscribed);
Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);
} }
} }
} }

+ 42
- 1
Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs View File

@@ -12,7 +12,48 @@ namespace MQTTnet.TestApp.NetCore
{ {
public static class PerformanceTest public static class PerformanceTest
{ {
public static void Run()
public static void RunClientOnly()
{
try
{
var options = new MqttClientOptions
{
ChannelOptions = new MqttClientTcpOptions
{
Server = "127.0.0.1"
},
CleanSession = true
};

var client = new MqttFactory().CreateMqttClient();
client.ConnectAsync(options).GetAwaiter().GetResult();

var message = CreateMessage();
var stopwatch = new Stopwatch();

for (var i = 0; i < 10; i++)
{
var sentMessagesCount = 0;

stopwatch.Restart();
while (stopwatch.ElapsedMilliseconds < 1000)
{
client.PublishAsync(message).GetAwaiter().GetResult();
sentMessagesCount++;
}

Console.WriteLine($"Sending {sentMessagesCount} messages per second. #" + (i + 1));

GC.Collect();
}
}
catch (Exception exception)
{
Console.WriteLine(exception);
}
}

public static void RunClientAndServer()
{ {
try try
{ {


+ 13
- 1
Tests/MQTTnet.TestApp.NetCore/Program.cs View File

@@ -22,6 +22,8 @@ namespace MQTTnet.TestApp.NetCore
Console.WriteLine("5 = Start public broker test"); Console.WriteLine("5 = Start public broker test");
Console.WriteLine("6 = Start server & client"); Console.WriteLine("6 = Start server & client");
Console.WriteLine("7 = Client flow test"); Console.WriteLine("7 = Client flow test");
Console.WriteLine("8 = Start performance test (client only)");
Console.WriteLine("9 = Start server (no trace)");


var pressedKey = Console.ReadKey(true); var pressedKey = Console.ReadKey(true);
if (pressedKey.KeyChar == '1') if (pressedKey.KeyChar == '1')
@@ -34,7 +36,7 @@ namespace MQTTnet.TestApp.NetCore
} }
else if (pressedKey.KeyChar == '3') else if (pressedKey.KeyChar == '3')
{ {
PerformanceTest.Run();
PerformanceTest.RunClientAndServer();
return; return;
} }
else if (pressedKey.KeyChar == '4') else if (pressedKey.KeyChar == '4')
@@ -53,6 +55,16 @@ namespace MQTTnet.TestApp.NetCore
{ {
Task.Run(ClientFlowTest.RunAsync); Task.Run(ClientFlowTest.RunAsync);
} }
else if (pressedKey.KeyChar == '8')
{
PerformanceTest.RunClientOnly();
return;
}
else if (pressedKey.KeyChar == '9')
{
ServerTest.RunEmptyServer();
return;
}


Thread.Sleep(Timeout.Infinite); Thread.Sleep(Timeout.Infinite);
} }


+ 9
- 2
Tests/MQTTnet.TestApp.NetCore/ServerTest.cs View File

@@ -8,12 +8,19 @@ namespace MQTTnet.TestApp.NetCore
{ {
public static class ServerTest public static class ServerTest
{ {
public static void RunEmptyServer()
{
var mqttServer = new MqttFactory().CreateMqttServer();
mqttServer.StartAsync(new MqttServerOptions()).GetAwaiter().GetResult();

Console.WriteLine("Press any key to exit.");
Console.ReadLine();
}

public static async Task RunAsync() public static async Task RunAsync()
{ {
try try
{ {
MqttNetConsoleLogger.ForwardToConsole();

var options = new MqttServerOptions var options = new MqttServerOptions
{ {
ConnectionValidator = p => ConnectionValidator = p =>


Loading…
Cancel
Save