From 6a2bded18184a8420fb18e88318dfc162ed09efd Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sun, 27 Jan 2019 11:15:32 +0100 Subject: [PATCH] Refactor session and connection handling in server. Fix QoS level 2 issues. --- Build/MQTTnet.nuspec | 3 + README.md | 2 +- .../MqttConnectionContext.cs | 2 +- .../ManagedMqttClient.cs | 2 +- .../ManagedMqttClientExtensions.cs | 87 ++ Source/MQTTnet/Adapter/IMqttChannelAdapter.cs | 2 +- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 48 +- .../ApplicationMessagePublisherExtensions.cs | 81 -- Source/MQTTnet/Channel/IMqttChannel.cs | 2 +- ...ult.cs => MqttClientAuthenticateResult.cs} | 2 +- .../MqttClientConnectedEventArgs.cs | 6 +- Source/MQTTnet/Client/IMqttClient.cs | 9 +- Source/MQTTnet/Client/MqttClient.cs | 214 ++-- Source/MQTTnet/Client/MqttClientExtensions.cs | 114 ++ .../Options/MqttClientOptionsBuilder.cs | 5 + .../MQTTnet/Formatter/IMqttDataConverter.cs | 10 +- .../Formatter/V3/MqttV310DataConverter.cs | 40 +- .../Formatter/V5/MqttV500DataConverter.cs | 67 +- .../MQTTnet/IApplicationMessagePublisher.cs | 5 +- .../Implementations/MqttTcpChannel.Uwp.cs | 2 +- .../MQTTnet/Implementations/MqttTcpChannel.cs | 39 +- .../Implementations/MqttTcpServerListener.cs | 1 - .../Implementations/MqttWebSocketChannel.cs | 4 +- .../{TaskExtensions.cs => MqttTaskTimeout.cs} | 6 +- Source/MQTTnet/Internal/TestMqttChannel.cs | 2 +- .../MQTTnet/MqttApplicationMessageBuilder.cs | 17 +- .../PacketDispatcher/MqttPacketAwaiter.cs | 3 +- .../PacketDispatcher/MqttPacketDispatcher.cs | 28 +- Source/MQTTnet/Server/IMqttClientSession.cs | 7 +- .../Server/IMqttClientSessionStatus.cs | 33 - Source/MQTTnet/Server/IMqttServer.cs | 10 +- .../Server/IMqttServerConnectionValidator.cs | 2 +- Source/MQTTnet/Server/MqttClientConnection.cs | 568 +++++++++ .../Server/MqttClientDisconnectType.cs | 3 +- .../Server/MqttClientDisconnectedEventArgs.cs | 6 +- .../Server/MqttClientKeepAliveMonitor.cs | 21 +- Source/MQTTnet/Server/MqttClientSession.cs | 481 -------- ...ttClientSessionApplicationMessagesQueue.cs | 117 ++ .../MqttClientSessionPendingMessagesQueue.cs | 212 ---- .../MQTTnet/Server/MqttClientSessionStatus.cs | 52 - .../Server/MqttClientSessionsManager.cs | 233 ++-- .../Server/MqttClientSubscriptionsManager.cs | 43 +- .../Server/MqttEnqueuedApplicationMessage.cs | 4 +- .../Server/MqttRetainedMessagesManager.cs | 31 +- Source/MQTTnet/Server/MqttServer.cs | 20 +- .../MqttServerConnectionValidatorDelegate.cs | 2 +- .../Server/MqttServerEventDispatcher.cs | 4 +- Source/MQTTnet/Server/MqttServerExtensions.cs | 88 ++ .../Server/PrepareClientSessionResult.cs | 2 +- .../Server/Status/IMqttClientStatus.cs | 31 + .../Server/Status/IMqttSessionStatus.cs | 17 + .../MQTTnet/Server/Status/MqttClientStatus.cs | 43 + .../Server/Status/MqttSessionStatus.cs | 36 + .../MqttConnectionContextTest.cs | 4 +- .../ChannelAdapterBenchmark.cs | 2 +- .../MQTTnet.Benchmarks/SerializerBenchmark.cs | 2 +- Tests/MQTTnet.Core.Tests/ExtensionTests.cs | 13 +- .../{MqttClientTests.cs => Client_Tests.cs} | 16 +- .../TestEnvironment.cs} | 23 +- .../TestMqttCommunicationAdapter.cs | 29 +- .../TestMqttCommunicationAdapterFactory.cs | 3 +- .../{ => Mockups}/TestMqttServerAdapter.cs | 2 +- .../{ => Mockups}/TestServerExtensions.cs | 2 +- .../Mockups/TestServerStorage.cs | 22 + Tests/MQTTnet.Core.Tests/MqttClientTests.cs | 187 ++- .../MqttKeepAliveMonitorTests.cs | 11 +- Tests/MQTTnet.Core.Tests/MqttServerTests.cs | 1031 ----------------- .../MqttSubscriptionsManagerTests.cs | 12 +- .../MQTTnet.Core.Tests/Server_Status_Tests.cs | 159 +++ Tests/MQTTnet.Core.Tests/Server_Tests.cs | 970 ++++++++++++++++ .../MainPage.xaml | 6 +- .../MainPage.xaml.cs | 5 +- 72 files changed, 2913 insertions(+), 2455 deletions(-) delete mode 100644 Source/MQTTnet/ApplicationMessagePublisherExtensions.cs rename Source/MQTTnet/Client/Connecting/{MqttClientConnectResult.cs => MqttClientAuthenticateResult.cs} (79%) rename Source/MQTTnet/Internal/{TaskExtensions.cs => MqttTaskTimeout.cs} (84%) delete mode 100644 Source/MQTTnet/Server/IMqttClientSessionStatus.cs create mode 100644 Source/MQTTnet/Server/MqttClientConnection.cs delete mode 100644 Source/MQTTnet/Server/MqttClientSession.cs create mode 100644 Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs delete mode 100644 Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs delete mode 100644 Source/MQTTnet/Server/MqttClientSessionStatus.cs create mode 100644 Source/MQTTnet/Server/Status/IMqttClientStatus.cs create mode 100644 Source/MQTTnet/Server/Status/IMqttSessionStatus.cs create mode 100644 Source/MQTTnet/Server/Status/MqttClientStatus.cs create mode 100644 Source/MQTTnet/Server/Status/MqttSessionStatus.cs rename Tests/MQTTnet.Core.Tests/MQTTv5/{MqttClientTests.cs => Client_Tests.cs} (95%) rename Tests/MQTTnet.Core.Tests/{TestSetup.cs => Mockups/TestEnvironment.cs} (88%) rename Tests/MQTTnet.Core.Tests/{ => Mockups}/TestMqttCommunicationAdapter.cs (67%) rename Tests/MQTTnet.Core.Tests/{ => Mockups}/TestMqttCommunicationAdapterFactory.cs (92%) rename Tests/MQTTnet.Core.Tests/{ => Mockups}/TestMqttServerAdapter.cs (98%) rename Tests/MQTTnet.Core.Tests/{ => Mockups}/TestServerExtensions.cs (97%) create mode 100644 Tests/MQTTnet.Core.Tests/Mockups/TestServerStorage.cs delete mode 100644 Tests/MQTTnet.Core.Tests/MqttServerTests.cs create mode 100644 Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs create mode 100644 Tests/MQTTnet.Core.Tests/Server_Tests.cs diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 564e47e..2a8a763 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -11,10 +11,13 @@ false MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker). * [Core] Added support for MQTTv5 packages. +* [Core] Performance improvements (removed several exceptions). * [Client] Added new MQTTv5 features to options builder. * [Client] Added uniform API across all supported MQTT versions (BREAKING CHANGE!) * [Client] The client will now avoid sending an ACK if an exception has been thrown in message handler (thanks to @ramonsmits). +* [Client] Fixed issues in QoS 2 handling which leads to message loss. * [Server] Added support for MQTTv5 clients. The server will still return _success_ for all cases at the moment even if more granular codes are available. +* [Server] Fixed issues in QoS 2 handling which leads to message loss. * [Note] Due to MQTTv5 a lot of new classes were introduced. This required adding new namespaces as well. Most classes are backward compatible but new namespaces must be added. diff --git a/README.md b/README.md index e0e0bc2..b74c1fb 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov * Uniform API across all supported versions of the MQTT protocol * Interfaces included for mocking and testing * Access to internal trace messages -* Unit tested (~120 tests) +* Unit tested (~130 tests) \* Tested on local machine (Intel i7 8700K) with MQTTnet client and server running in the same process using the TCP channel. The app for verification is part of this repository and stored in _/Tests/MQTTnet.TestApp.NetCore_. diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index b24f6e1..999b340 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -105,7 +105,7 @@ namespace MQTTnet.AspNetCore return null; } - public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + public async Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout, CancellationToken cancellationToken) { var buffer = PacketFormatterAdapter.Encode(packet).AsMemory(); var output = Connection.Transport.Output; diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs index 6fcbee9..b4eec1c 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs @@ -110,7 +110,7 @@ namespace MQTTnet.Extensions.ManagedClient return Task.FromResult(0); } - public async Task PublishAsync(MqttApplicationMessage applicationMessage) + public async Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientExtensions.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientExtensions.cs index ba5a16f..9a63565 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientExtensions.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientExtensions.cs @@ -1,5 +1,8 @@ using System; +using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; +using MQTTnet.Client.Publishing; using MQTTnet.Protocol; namespace MQTTnet.Extensions.ManagedClient @@ -35,5 +38,89 @@ namespace MQTTnet.Extensions.ManagedClient return managedClient.UnsubscribeAsync(topicFilters); } + + public static async Task PublishAsync(this IManagedMqttClient client, IEnumerable applicationMessages) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); + + foreach (var applicationMessage in applicationMessages) + { + await client.PublishAsync(applicationMessage).ConfigureAwait(false); + } + } + + public static Task PublishAsync(this IManagedMqttClient client, MqttApplicationMessage applicationMessage) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + return client.PublishAsync(applicationMessage, CancellationToken.None); + } + + public static async Task PublishAsync(this IManagedMqttClient client, params MqttApplicationMessage[] applicationMessages) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); + + foreach (var applicationMessage in applicationMessages) + { + await client.PublishAsync(applicationMessage, CancellationToken.None).ConfigureAwait(false); + } + } + + public static Task PublishAsync(this IManagedMqttClient client, string topic) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(builder => builder + .WithTopic(topic)); + } + + public static Task PublishAsync(this IManagedMqttClient client, string topic, string payload) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(builder => builder + .WithTopic(topic) + .WithPayload(payload)); + } + + public static Task PublishAsync(this IManagedMqttClient client, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(builder => builder + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel(qualityOfServiceLevel)); + } + + public static Task PublishAsync(this IManagedMqttClient client, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, bool retain) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(builder => builder + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel(qualityOfServiceLevel) + .WithRetainFlag(retain)); + } + + public static Task PublishAsync(this IManagedMqttClient client, Func builder, CancellationToken cancellationToken) + { + var message = builder(new MqttApplicationMessageBuilder()).Build(); + return client.PublishAsync(message, cancellationToken); + } + + public static Task PublishAsync(this IManagedMqttClient client, Func builder) + { + var message = builder(new MqttApplicationMessageBuilder()).Build(); + return client.PublishAsync(message, CancellationToken.None); + } } } diff --git a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs index 6310664..f499322 100644 --- a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs @@ -20,7 +20,7 @@ namespace MQTTnet.Adapter Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken); - Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken); + Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout, CancellationToken cancellationToken); Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); } diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index ded2a1f..1bfcff0 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -53,11 +53,14 @@ namespace MQTTnet.Adapter try { - _logger.Verbose("Connecting [Timeout={0}]", timeout); - - await Internal.TaskExtensions - .TimeoutAfterAsync(ct => _channel.ConnectAsync(ct), timeout, cancellationToken) - .ConfigureAwait(false); + if (timeout == TimeSpan.Zero) + { + await _channel.ConnectAsync(cancellationToken).ConfigureAwait(false); + } + else + { + await MqttTaskTimeout.WaitAsync(t => _channel.ConnectAsync(t), timeout, cancellationToken).ConfigureAwait(false); + } } catch (Exception exception) { @@ -76,11 +79,15 @@ namespace MQTTnet.Adapter try { - _logger.Verbose("Disconnecting [Timeout={0}]", timeout); - - await Internal.TaskExtensions - .TimeoutAfterAsync(ct => _channel.DisconnectAsync(), timeout, cancellationToken) - .ConfigureAwait(false); + if (timeout == TimeSpan.Zero) + { + await _channel.DisconnectAsync(cancellationToken).ConfigureAwait(false); + } + else + { + await MqttTaskTimeout.WaitAsync( + t => _channel.DisconnectAsync(t), timeout, cancellationToken).ConfigureAwait(false); + } } catch (Exception exception) { @@ -93,13 +100,23 @@ namespace MQTTnet.Adapter } } - public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + public async Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout, CancellationToken cancellationToken) { await _writerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); try { var packetData = PacketFormatterAdapter.Encode(packet); - await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); + + if (timeout == TimeSpan.Zero) + { + await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); + } + else + { + await MqttTaskTimeout.WaitAsync( + t => _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, t), timeout, cancellationToken).ConfigureAwait(false); + } + PacketFormatterAdapter.FreeBuffer(); _logger.Verbose("TX ({0} bytes) >>> {1}", packetData.Count, packet); @@ -126,14 +143,13 @@ namespace MQTTnet.Adapter try { ReceivedMqttPacket receivedMqttPacket; - - if (timeout > TimeSpan.Zero) + if (timeout == TimeSpan.Zero) { - receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfterAsync(ReceiveAsync, timeout, cancellationToken).ConfigureAwait(false); + receivedMqttPacket = await ReceiveAsync(cancellationToken).ConfigureAwait(false); } else { - receivedMqttPacket = await ReceiveAsync(cancellationToken).ConfigureAwait(false); + receivedMqttPacket = await MqttTaskTimeout.WaitAsync(ReceiveAsync, timeout, cancellationToken).ConfigureAwait(false); } if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) diff --git a/Source/MQTTnet/ApplicationMessagePublisherExtensions.cs b/Source/MQTTnet/ApplicationMessagePublisherExtensions.cs deleted file mode 100644 index ea0b83b..0000000 --- a/Source/MQTTnet/ApplicationMessagePublisherExtensions.cs +++ /dev/null @@ -1,81 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using MQTTnet.Client.Publishing; -using MQTTnet.Protocol; - -namespace MQTTnet -{ - public static class ApplicationMessagePublisherExtensions - { - public static async Task PublishAsync(this IApplicationMessagePublisher publisher, IEnumerable applicationMessages) - { - if (publisher == null) throw new ArgumentNullException(nameof(publisher)); - if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); - - foreach (var applicationMessage in applicationMessages) - { - await publisher.PublishAsync(applicationMessage).ConfigureAwait(false); - } - } - - public static async Task PublishAsync(this IApplicationMessagePublisher publisher, params MqttApplicationMessage[] applicationMessages) - { - if (publisher == null) throw new ArgumentNullException(nameof(publisher)); - if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); - - foreach (var applicationMessage in applicationMessages) - { - await publisher.PublishAsync(applicationMessage).ConfigureAwait(false); - } - } - - public static Task PublishAsync(this IApplicationMessagePublisher publisher, string topic) - { - if (publisher == null) throw new ArgumentNullException(nameof(publisher)); - if (topic == null) throw new ArgumentNullException(nameof(topic)); - - return publisher.PublishAsync(builder => builder - .WithTopic(topic)); - } - - public static Task PublishAsync(this IApplicationMessagePublisher publisher, string topic, string payload) - { - if (publisher == null) throw new ArgumentNullException(nameof(publisher)); - if (topic == null) throw new ArgumentNullException(nameof(topic)); - - return publisher.PublishAsync(builder => builder - .WithTopic(topic) - .WithPayload(payload)); - } - - public static Task PublishAsync(this IApplicationMessagePublisher publisher, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) - { - if (publisher == null) throw new ArgumentNullException(nameof(publisher)); - if (topic == null) throw new ArgumentNullException(nameof(topic)); - - return publisher.PublishAsync(builder => builder - .WithTopic(topic) - .WithPayload(payload) - .WithQualityOfServiceLevel(qualityOfServiceLevel)); - } - - public static Task PublishAsync(this IApplicationMessagePublisher publisher, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, bool retain) - { - if (publisher == null) throw new ArgumentNullException(nameof(publisher)); - if (topic == null) throw new ArgumentNullException(nameof(topic)); - - return publisher.PublishAsync(builder => builder - .WithTopic(topic) - .WithPayload(payload) - .WithQualityOfServiceLevel(qualityOfServiceLevel) - .WithRetainFlag(retain)); - } - - public static Task PublishAsync(this IApplicationMessagePublisher publisher, Func builder) - { - var message = builder(new MqttApplicationMessageBuilder()).Build(); - return publisher.PublishAsync(message); - } - } -} diff --git a/Source/MQTTnet/Channel/IMqttChannel.cs b/Source/MQTTnet/Channel/IMqttChannel.cs index 6050bcb..4dcb668 100644 --- a/Source/MQTTnet/Channel/IMqttChannel.cs +++ b/Source/MQTTnet/Channel/IMqttChannel.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Channel string Endpoint { get; } Task ConnectAsync(CancellationToken cancellationToken); - Task DisconnectAsync(); + Task DisconnectAsync(CancellationToken cancellationToken); Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); diff --git a/Source/MQTTnet/Client/Connecting/MqttClientConnectResult.cs b/Source/MQTTnet/Client/Connecting/MqttClientAuthenticateResult.cs similarity index 79% rename from Source/MQTTnet/Client/Connecting/MqttClientConnectResult.cs rename to Source/MQTTnet/Client/Connecting/MqttClientAuthenticateResult.cs index 05f53f0..db6e5be 100644 --- a/Source/MQTTnet/Client/Connecting/MqttClientConnectResult.cs +++ b/Source/MQTTnet/Client/Connecting/MqttClientAuthenticateResult.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Client.Connecting { - public class MqttClientConnectResult + public class MqttClientAuthenticateResult { public bool IsSessionPresent { get; set; } diff --git a/Source/MQTTnet/Client/Connecting/MqttClientConnectedEventArgs.cs b/Source/MQTTnet/Client/Connecting/MqttClientConnectedEventArgs.cs index c052986..e7f97b9 100644 --- a/Source/MQTTnet/Client/Connecting/MqttClientConnectedEventArgs.cs +++ b/Source/MQTTnet/Client/Connecting/MqttClientConnectedEventArgs.cs @@ -4,11 +4,11 @@ namespace MQTTnet.Client.Connecting { public class MqttClientConnectedEventArgs : EventArgs { - public MqttClientConnectedEventArgs(MqttClientConnectResult connectResult) + public MqttClientConnectedEventArgs(MqttClientAuthenticateResult authenticateResult) { - ConnectResult = connectResult ?? throw new ArgumentNullException(nameof(connectResult)); + AuthenticateResult = authenticateResult ?? throw new ArgumentNullException(nameof(authenticateResult)); } - public MqttClientConnectResult ConnectResult { get; } + public MqttClientAuthenticateResult AuthenticateResult { get; } } } diff --git a/Source/MQTTnet/Client/IMqttClient.cs b/Source/MQTTnet/Client/IMqttClient.cs index 3c9ded3..0e00f00 100644 --- a/Source/MQTTnet/Client/IMqttClient.cs +++ b/Source/MQTTnet/Client/IMqttClient.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Client.Connecting; using MQTTnet.Client.Disconnecting; @@ -16,10 +17,10 @@ namespace MQTTnet.Client event EventHandler Connected; event EventHandler Disconnected; - Task ConnectAsync(IMqttClientOptions options); - Task DisconnectAsync(MqttClientDisconnectOptions options); + Task ConnectAsync(IMqttClientOptions options, CancellationToken cancellationToken); + Task DisconnectAsync(MqttClientDisconnectOptions options, CancellationToken cancellationToken); - Task SubscribeAsync(MqttClientSubscribeOptions options); - Task UnsubscribeAsync(MqttClientUnsubscribeOptions options); + Task SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken); + Task UnsubscribeAsync(MqttClientUnsubscribeOptions options, CancellationToken cancellationToken); } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 90c5fa0..97cd3f2 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -12,7 +12,6 @@ using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Formatter; using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -29,10 +28,9 @@ namespace MQTTnet.Client private readonly IMqttClientAdapterFactory _adapterFactory; private readonly IMqttNetChildLogger _logger; - private CancellationTokenSource _cancellationTokenSource; + private CancellationTokenSource _backgroundCancellationTokenSource; private Task _packetReceiverTask; private Task _keepAlivePacketsSenderTask; - private Task _backgroundWorkerTask; private IMqttChannelAdapter _adapter; private bool _cleanDisconnectInitiated; @@ -47,15 +45,18 @@ namespace MQTTnet.Client } public event EventHandler Connected; + public event EventHandler Disconnected; public IMqttApplicationMessageHandler ReceivedApplicationMessageHandler { get; set; } + public event EventHandler ApplicationMessageReceived; public bool IsConnected { get; private set; } + public IMqttClientOptions Options { get; private set; } - public async Task ConnectAsync(IMqttClientOptions options) + public async Task ConnectAsync(IMqttClientOptions options, CancellationToken cancellationToken) { if (options == null) throw new ArgumentNullException(nameof(options)); if (options.ChannelOptions == null) throw new ArgumentException("ChannelOptions are not set."); @@ -69,37 +70,34 @@ namespace MQTTnet.Client _packetIdentifierProvider.Reset(); _packetDispatcher.Reset(); - _cancellationTokenSource = new CancellationTokenSource(); - var cancellationToken = _cancellationTokenSource.Token; + _backgroundCancellationTokenSource = new CancellationTokenSource(); + var backgroundCancellationToken = _backgroundCancellationTokenSource.Token; _disconnectGate = 0; var adapter = _adapterFactory.CreateClientAdapter(options, _logger); _adapter = adapter; - _logger.Verbose($"Trying to connect with server ({Options.ChannelOptions})."); - await _adapter.ConnectAsync(Options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); + _logger.Verbose($"Trying to connect with server '{options.ChannelOptions}' (Timeout={options.CommunicationTimeout})."); + await _adapter.ConnectAsync(options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); _logger.Verbose("Connection with server established."); - _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(cancellationToken), cancellationToken); + _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(backgroundCancellationToken), backgroundCancellationToken); - var connectResult = await AuthenticateAsync(adapter, options.WillMessage, cancellationToken).ConfigureAwait(false); - _logger.Verbose("MQTT connection with server established."); + var authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, cancellationToken).ConfigureAwait(false); _sendTracker.Restart(); if (Options.KeepAlivePeriod != TimeSpan.Zero) { - _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(cancellationToken), cancellationToken); + _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken); } - _backgroundWorkerTask = Task.Run(() => TryProcessReceivedPacketsAsync(cancellationToken), cancellationToken); - IsConnected = true; - Connected?.Invoke(this, new MqttClientConnectedEventArgs(connectResult)); + Connected?.Invoke(this, new MqttClientConnectedEventArgs(authenticateResult)); _logger.Info("Connected."); - return connectResult; + return authenticateResult; } catch (Exception exception) { @@ -114,16 +112,16 @@ namespace MQTTnet.Client } } - public async Task DisconnectAsync(MqttClientDisconnectOptions options) + public async Task DisconnectAsync(MqttClientDisconnectOptions options, CancellationToken cancellationToken) { try { _cleanDisconnectInitiated = true; - if (IsConnected && _cancellationTokenSource?.IsCancellationRequested == false) + if (IsConnected) { - var disconnectPacket = CreateDisconnectPacket(options); - await SendAsync(disconnectPacket, _cancellationTokenSource.Token).ConfigureAwait(false); + var disconnectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateDisconnectPacket(options); + await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false); } } finally @@ -135,7 +133,7 @@ namespace MQTTnet.Client } } - public async Task SubscribeAsync(MqttClientSubscribeOptions options) + public async Task SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -144,11 +142,11 @@ namespace MQTTnet.Client var subscribePacket = _adapter.PacketFormatterAdapter.DataConverter.CreateSubscribePacket(options); subscribePacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - var subAckPacket = await SendAndReceiveAsync(subscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false); + var subAckPacket = await SendAndReceiveAsync(subscribePacket, cancellationToken).ConfigureAwait(false); return _adapter.PacketFormatterAdapter.DataConverter.CreateClientSubscribeResult(subscribePacket, subAckPacket); } - public async Task UnsubscribeAsync(MqttClientUnsubscribeOptions options) + public async Task UnsubscribeAsync(MqttClientUnsubscribeOptions options, CancellationToken cancellationToken) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -157,11 +155,11 @@ namespace MQTTnet.Client var unsubscribePacket = _adapter.PacketFormatterAdapter.DataConverter.CreateUnsubscribePacket(options); unsubscribePacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - var unsubAckPacket = await SendAndReceiveAsync(unsubscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false); + var unsubAckPacket = await SendAndReceiveAsync(unsubscribePacket, cancellationToken).ConfigureAwait(false); return _adapter.PacketFormatterAdapter.DataConverter.CreateClientUnsubscribeResult(unsubscribePacket, unsubAckPacket); } - public async Task PublishAsync(MqttApplicationMessage applicationMessage) + public Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); @@ -173,26 +171,15 @@ namespace MQTTnet.Client { case MqttQualityOfServiceLevel.AtMostOnce: { - // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] - await SendAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false); - return new MqttClientPublishResult(); + return PublishAtMostOnce(publishPacket, cancellationToken); } case MqttQualityOfServiceLevel.AtLeastOnce: { - publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - var response = await SendAndReceiveAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false); - - var result = new MqttClientPublishResult(); - if (response.ReasonCode != null) - { - result.ReasonCode = (MqttClientPublishReasonCode)response.ReasonCode; - } - - return result; + return PublishAtLeastOnceAsync(publishPacket, cancellationToken); } case MqttQualityOfServiceLevel.ExactlyOnce: { - return await PublishExactlyOnceAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false); + return PublishExactlyOnceAsync(publishPacket, cancellationToken); } default: { @@ -203,15 +190,15 @@ namespace MQTTnet.Client public void Dispose() { - _cancellationTokenSource?.Cancel(false); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; + _backgroundCancellationTokenSource?.Cancel(false); + _backgroundCancellationTokenSource?.Dispose(); + _backgroundCancellationTokenSource = null; _adapter?.Dispose(); _adapter = null; } - private async Task AuthenticateAsync(IMqttChannelAdapter channelAdapter, MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken) + private async Task AuthenticateAsync(IMqttChannelAdapter channelAdapter, MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken) { var connectPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnectPacket( willApplicationMessage, @@ -225,6 +212,8 @@ namespace MQTTnet.Client throw new MqttConnectingFailedException(result.ResultCode); } + _logger.Verbose("Authenticated MQTT connection with server established."); + return result; } @@ -250,12 +239,12 @@ namespace MQTTnet.Client { if (_adapter != null) { + _logger.Verbose("Disconnecting [Timeout={0}]", Options.CommunicationTimeout); await _adapter.DisconnectAsync(Options.CommunicationTimeout, CancellationToken.None).ConfigureAwait(false); } await WaitForTaskAsync(_packetReceiverTask, sender).ConfigureAwait(false); await WaitForTaskAsync(_keepAlivePacketsSenderTask, sender).ConfigureAwait(false); - await WaitForTaskAsync(_backgroundWorkerTask, sender).ConfigureAwait(false); _logger.Verbose("Disconnected from adapter."); } @@ -279,12 +268,12 @@ namespace MQTTnet.Client { try { - if (_cancellationTokenSource?.IsCancellationRequested == true) + if (_backgroundCancellationTokenSource?.IsCancellationRequested == true) { return; } - _cancellationTokenSource?.Cancel(false); + _backgroundCancellationTokenSource?.Cancel(false); } catch (Exception exception) { @@ -295,11 +284,14 @@ namespace MQTTnet.Client private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromResult(0); + } _sendTracker.Restart(); - return _adapter.SendPacketAsync(packet, cancellationToken); + return _adapter.SendPacketAsync(packet, Options.CommunicationTimeout, cancellationToken); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket @@ -318,7 +310,7 @@ namespace MQTTnet.Client { try { - await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); + await _adapter.SendPacketAsync(requestPacket, Options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); return await packetAwaiter.WaitOneAsync(Options.CommunicationTimeout).ConfigureAwait(false); } catch (MqttCommunicationTimedOutException) @@ -398,12 +390,17 @@ namespace MQTTnet.Client return; } - if (packet == null && !DisconnectIsPending()) + if (packet == null) { - await DisconnectInternalAsync(_packetReceiverTask, null).ConfigureAwait(false); + if (!DisconnectIsPending()) + { + await DisconnectInternalAsync(_packetReceiverTask, null).ConfigureAwait(false); + } + + return; } - _packetDispatcher.Dispatch(packet); + await TryProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false); } } catch (Exception exception) @@ -446,17 +443,25 @@ namespace MQTTnet.Client { await TryProcessReceivedPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false); } + else if (packet is MqttPubRelPacket pubRelPacket) + { + await SendAsync(new MqttPubCompPacket + { + PacketIdentifier = pubRelPacket.PacketIdentifier, + ReasonCode = MqttPubCompReasonCode.Success + }, cancellationToken).ConfigureAwait(false); + } else if (packet is MqttPingReqPacket) { await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false); } else if (packet is MqttDisconnectPacket) { - await DisconnectAsync(null).ConfigureAwait(false); + await DisconnectAsync(null, cancellationToken).ConfigureAwait(false); } else { - throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + _packetDispatcher.Dispatch(packet); } } catch (Exception exception) @@ -507,22 +512,15 @@ namespace MQTTnet.Client } else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) { + await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); + var pubRecPacket = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier, ReasonCode = MqttPubRecReasonCode.Success }; - var pubRelPacket = await SendAndReceiveAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); - - // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) - await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); - - await SendAsync(new MqttPubCompPacket - { - PacketIdentifier = pubRelPacket.PacketIdentifier, - ReasonCode = MqttPubCompReasonCode.Success - }, cancellationToken).ConfigureAwait(false); + await SendAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); } else { @@ -535,29 +533,18 @@ namespace MQTTnet.Client } } - private async Task TryProcessReceivedPacketsAsync(CancellationToken cancellationToken) + private async Task PublishAtMostOnce(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - try - { - while (!cancellationToken.IsCancellationRequested) - { - var packet = _packetDispatcher.Take(cancellationToken); - - if (cancellationToken.IsCancellationRequested) - { - return; - } + // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] + await SendAsync(publishPacket, cancellationToken).ConfigureAwait(false); + return _adapter.PacketFormatterAdapter.DataConverter.CreatePublishResult(null); + } - await TryProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false); - } - } - catch (OperationCanceledException) - { - } - catch (Exception exception) - { - _logger.Error(exception, "Error while processing packet."); - } + private async Task PublishAtLeastOnceAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + { + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); + var response = await SendAndReceiveAsync(publishPacket, cancellationToken).ConfigureAwait(false); + return _adapter.PacketFormatterAdapter.DataConverter.CreatePublishResult(response); } private async Task PublishExactlyOnceAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) @@ -566,8 +553,6 @@ namespace MQTTnet.Client var pubRecPacket = await SendAndReceiveAsync(publishPacket, cancellationToken).ConfigureAwait(false); - // TODO: Check response code. - var pubRelPacket = new MqttPubRelPacket { PacketIdentifier = publishPacket.PacketIdentifier, @@ -576,36 +561,9 @@ namespace MQTTnet.Client var pubCompPacket = await SendAndReceiveAsync(pubRelPacket, cancellationToken).ConfigureAwait(false); - // TODO: Check response code. - - var result = new MqttClientPublishResult(); - - if (pubRecPacket.ReasonCode != null) - { - result.ReasonCode = (MqttClientPublishReasonCode)pubRecPacket.ReasonCode; - } - - return result; + return _adapter.PacketFormatterAdapter.DataConverter.CreatePublishResult(pubRecPacket, pubCompPacket); } - ////private void StartReceivingPackets(CancellationToken cancellationToken) - ////{ - //// _packetReceiverTask = Task.Factory.StartNew( - //// () => TryReceivePacketsAsync(cancellationToken), - //// cancellationToken, - //// TaskCreationOptions.LongRunning, - //// TaskScheduler.Default).Unwrap(); - ////} - - ////private void StartSendingKeepAliveMessages(CancellationToken cancellationToken) - ////{ - //// _keepAlivePacketsSenderTask = Task.Factory.StartNew( - //// () => TrySendKeepAliveMessagesAsync(cancellationToken), - //// cancellationToken, - //// TaskCreationOptions.LongRunning, - //// TaskScheduler.Default).Unwrap(); - ////} - private Task HandleReceivedApplicationMessageAsync(MqttPublishPacket publishPacket) { var applicationMessage = _adapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket); @@ -647,31 +605,5 @@ namespace MQTTnet.Client { return Interlocked.CompareExchange(ref _disconnectGate, 1, 0) != 0; } - - private MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options) - { - var packet = new MqttDisconnectPacket(); - - if (_adapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) - { - if (options == null) - { - packet.ReasonCode = MqttDisconnectReasonCode.NormalDisconnection; - } - else - { - packet.ReasonCode = (MqttDisconnectReasonCode)options.ReasonCode; - } - } - else - { - if (options != null) - { - throw new MqttProtocolViolationException("Reason codes for disconnect are only supported for MQTTv5."); - } - } - - return packet; - } } } diff --git a/Source/MQTTnet/Client/MqttClientExtensions.cs b/Source/MQTTnet/Client/MqttClientExtensions.cs index ef6f5ab..752763d 100644 --- a/Source/MQTTnet/Client/MqttClientExtensions.cs +++ b/Source/MQTTnet/Client/MqttClientExtensions.cs @@ -1,5 +1,11 @@ using System; +using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; +using MQTTnet.Client.Connecting; +using MQTTnet.Client.Disconnecting; +using MQTTnet.Client.Options; +using MQTTnet.Client.Publishing; using MQTTnet.Client.Receiving; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; @@ -98,5 +104,113 @@ namespace MQTTnet.Client return client.UnsubscribeAsync(options); } + + public static Task ConnectAsync(this IMqttClient client, IMqttClientOptions options) + { + return client.ConnectAsync(options, CancellationToken.None); + } + + public static Task DisconnectAsync(this IMqttClient client, MqttClientDisconnectOptions options) + { + return client.DisconnectAsync(options, CancellationToken.None); + } + + public static Task SubscribeAsync(this IMqttClient client, MqttClientSubscribeOptions options) + { + return client.SubscribeAsync(options, CancellationToken.None); + } + + public static Task UnsubscribeAsync(this IMqttClient client, MqttClientUnsubscribeOptions options) + { + return client.UnsubscribeAsync(options, CancellationToken.None); + } + + public static Task PublishAsync(this IMqttClient client, MqttApplicationMessage applicationMessage) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + return client.PublishAsync(applicationMessage, CancellationToken.None); + } + + public static async Task PublishAsync(this IMqttClient client, IEnumerable applicationMessages) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); + + foreach (var applicationMessage in applicationMessages) + { + await client.PublishAsync(applicationMessage).ConfigureAwait(false); + } + } + + public static Task PublishAsync(this IMqttClient client, string topic) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(new MqttApplicationMessageBuilder() + .WithTopic(topic) + .Build()); + } + + public static Task PublishAsync(this IMqttClient client, string topic, IEnumerable payload) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(new MqttApplicationMessageBuilder() + .WithTopic(topic) + .WithPayload(payload) + .Build()); + } + + public static Task PublishAsync(this IMqttClient client, string topic, string payload) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(new MqttApplicationMessageBuilder() + .WithTopic(topic) + .WithPayload(payload) + .Build()); + } + + public static Task PublishAsync(this IMqttClient client, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(new MqttApplicationMessageBuilder() + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel(qualityOfServiceLevel) + .Build()); + } + + public static Task PublishAsync(this IMqttClient client, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, bool retain) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(new MqttApplicationMessageBuilder() + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel(qualityOfServiceLevel) + .WithRetainFlag(retain) + .Build()); + } + + public static Task PublishAsync(this IMqttClient client, string topic, string payload, bool retain) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return client.PublishAsync(new MqttApplicationMessageBuilder() + .WithTopic(topic) + .WithPayload(payload) + .WithRetainFlag(retain) + .Build()); + } } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs index 832c866..e3780c7 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs @@ -31,6 +31,11 @@ namespace MQTTnet.Client.Options return this; } + public MqttClientOptionsBuilder WithNoKeepAlive() + { + return WithKeepAlivePeriod(TimeSpan.Zero); + } + public MqttClientOptionsBuilder WithKeepAlivePeriod(TimeSpan value) { _options.KeepAlivePeriod = value; diff --git a/Source/MQTTnet/Formatter/IMqttDataConverter.cs b/Source/MQTTnet/Formatter/IMqttDataConverter.cs index 19c2a28..5b347e5 100644 --- a/Source/MQTTnet/Formatter/IMqttDataConverter.cs +++ b/Source/MQTTnet/Formatter/IMqttDataConverter.cs @@ -1,5 +1,7 @@ using MQTTnet.Client.Connecting; +using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; +using MQTTnet.Client.Publishing; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Packets; @@ -12,7 +14,7 @@ namespace MQTTnet.Formatter MqttApplicationMessage CreateApplicationMessage(MqttPublishPacket publishPacket); - MqttClientConnectResult CreateClientConnectResult(MqttConnAckPacket connAckPacket); + MqttClientAuthenticateResult CreateClientConnectResult(MqttConnAckPacket connAckPacket); MqttConnectPacket CreateConnectPacket(MqttApplicationMessage willApplicationMessage, IMqttClientOptions options); @@ -23,5 +25,11 @@ namespace MQTTnet.Formatter MqttSubscribePacket CreateSubscribePacket(MqttClientSubscribeOptions options); MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options); + + MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options); + + MqttClientPublishResult CreatePublishResult(MqttPubAckPacket pubAckPacket); + + MqttClientPublishResult CreatePublishResult(MqttPubRecPacket pubRecPacket, MqttPubCompPacket pubCompPacket); } } diff --git a/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs b/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs index 1fad36b..49c25a0 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs @@ -1,7 +1,9 @@ using System; using System.Linq; using MQTTnet.Client.Connecting; +using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; +using MQTTnet.Client.Publishing; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Exceptions; @@ -44,7 +46,7 @@ namespace MQTTnet.Formatter.V3 }; } - public MqttClientConnectResult CreateClientConnectResult(MqttConnAckPacket connAckPacket) + public MqttClientAuthenticateResult CreateClientConnectResult(MqttConnAckPacket connAckPacket) { if (connAckPacket == null) throw new ArgumentNullException(nameof(connAckPacket)); @@ -91,7 +93,7 @@ namespace MQTTnet.Formatter.V3 throw new MqttProtocolViolationException("Received unexpected return code."); } - return new MqttClientConnectResult + return new MqttClientAuthenticateResult { IsSessionPresent = connAckPacket.IsSessionPresent, ResultCode = resultCode @@ -173,5 +175,39 @@ namespace MQTTnet.Formatter.V3 return unsubscribePacket; } + + public MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options) + { + if (options != null) + { + throw new MqttProtocolViolationException("Reason codes for disconnect are only supported for MQTTv5."); + } + + return new MqttDisconnectPacket(); + } + + public MqttClientPublishResult CreatePublishResult(MqttPubAckPacket pubAckPacket) + { + return new MqttClientPublishResult + { + ReasonCode = MqttClientPublishReasonCode.Success + }; + } + + public MqttClientPublishResult CreatePublishResult(MqttPubRecPacket pubRecPacket, MqttPubCompPacket pubCompPacket) + { + if (pubRecPacket == null || pubCompPacket == null) + { + return new MqttClientPublishResult + { + ReasonCode = MqttClientPublishReasonCode.UnspecifiedError + }; + } + + return new MqttClientPublishResult + { + ReasonCode = MqttClientPublishReasonCode.Success + }; + } } } diff --git a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs index b5522e1..e3bcb4e 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs @@ -2,11 +2,14 @@ using System.Collections.Generic; using System.Linq; using MQTTnet.Client.Connecting; +using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; +using MQTTnet.Client.Publishing; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Exceptions; using MQTTnet.Packets; +using MQTTnet.Protocol; namespace MQTTnet.Formatter.V5 { @@ -55,11 +58,11 @@ namespace MQTTnet.Formatter.V5 }; } - public MqttClientConnectResult CreateClientConnectResult(MqttConnAckPacket connAckPacket) + public MqttClientAuthenticateResult CreateClientConnectResult(MqttConnAckPacket connAckPacket) { if (connAckPacket == null) throw new ArgumentNullException(nameof(connAckPacket)); - return new MqttClientConnectResult + return new MqttClientAuthenticateResult { IsSessionPresent = connAckPacket.IsSessionPresent, ResultCode = (MqttClientConnectResultCode)connAckPacket.ReasonCode.Value @@ -158,5 +161,65 @@ namespace MQTTnet.Formatter.V5 return packet; } + + public MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options) + { + var packet = new MqttDisconnectPacket(); + + if (options == null) + { + packet.ReasonCode = MqttDisconnectReasonCode.NormalDisconnection; + } + else + { + packet.ReasonCode = (MqttDisconnectReasonCode)options.ReasonCode; + } + + return packet; + } + + public MqttClientPublishResult CreatePublishResult(MqttPubAckPacket pubAckPacket) + { + var result = new MqttClientPublishResult(); + + if (pubAckPacket?.ReasonCode != null) + { + result.ReasonCode = (MqttClientPublishReasonCode)pubAckPacket.ReasonCode; + } + + return result; + } + + public MqttClientPublishResult CreatePublishResult(MqttPubRecPacket pubRecPacket, MqttPubCompPacket pubCompPacket) + { + if (pubRecPacket == null || pubCompPacket == null) + { + return new MqttClientPublishResult + { + ReasonCode = MqttClientPublishReasonCode.UnspecifiedError + }; + } + + if (pubCompPacket.ReasonCode == MqttPubCompReasonCode.PacketIdentifierNotFound) + { + return new MqttClientPublishResult + { + ReasonCode = MqttClientPublishReasonCode.UnspecifiedError + }; + } + + var result = new MqttClientPublishResult + { + ReasonCode = MqttClientPublishReasonCode.Success + }; + + if (pubRecPacket.ReasonCode.HasValue) + { + // Both enums share the same values. + result.ReasonCode = (MqttClientPublishReasonCode)pubRecPacket.ReasonCode.Value; + } + + return result; + } } } diff --git a/Source/MQTTnet/IApplicationMessagePublisher.cs b/Source/MQTTnet/IApplicationMessagePublisher.cs index a7c0d4a..bf17718 100644 --- a/Source/MQTTnet/IApplicationMessagePublisher.cs +++ b/Source/MQTTnet/IApplicationMessagePublisher.cs @@ -1,10 +1,11 @@ -using System.Threading.Tasks; +using System.Threading; +using System.Threading.Tasks; using MQTTnet.Client.Publishing; namespace MQTTnet { public interface IApplicationMessagePublisher { - Task PublishAsync(MqttApplicationMessage applicationMessage); + Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken); } } diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs index fa20b08..ea5d6cc 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs @@ -81,7 +81,7 @@ namespace MQTTnet.Implementations CreateStreams(); } - public Task DisconnectAsync() + public Task DisconnectAsync(CancellationToken cancellationToken) { Dispose(); return Task.FromResult(0); diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 802463a..4a17885 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -72,7 +72,7 @@ namespace MQTTnet.Implementations CreateStream(sslStream); } - public Task DisconnectAsync() + public Task DisconnectAsync(CancellationToken cancellationToken) { Dispose(); return Task.FromResult(0); @@ -81,8 +81,13 @@ namespace MQTTnet.Implementations public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { // Workaround for: https://github.com/dotnet/corefx/issues/24430 - using (cancellationToken.Register(() => _socket.Dispose())) + using (cancellationToken.Register(Dispose)) { + if (cancellationToken.IsCancellationRequested) + { + return 0; + } + return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } } @@ -90,8 +95,13 @@ namespace MQTTnet.Implementations public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { // Workaround for: https://github.com/dotnet/corefx/issues/24430 - using (cancellationToken.Register(() => _socket.Dispose())) + using (cancellationToken.Register(Dispose)) { + if (cancellationToken.IsCancellationRequested) + { + return; + } + await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); //await _stream.FlushAsync(cancellationToken); } @@ -99,16 +109,21 @@ namespace MQTTnet.Implementations public void Dispose() { - Cleanup(ref _stream, s => s.Dispose()); - Cleanup(ref _socket, s => - { - //if (s.Connected) - //{ - // s.Shutdown(SocketShutdown.Both); - //} + _socket = null; - s.Dispose(); - }); + // When the stream is disposed it will also close the socket and this will also dispose it. + // So there is no need to dispose the socket again. + // https://stackoverflow.com/questions/3601521/should-i-manually-dispose-the-socket-after-closing-it + try + { + _stream?.Dispose(); + } + catch (ObjectDisposedException) + { + } + catch (NullReferenceException) + { + } } private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index 3803689..c550ebb 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -3,7 +3,6 @@ using System; using System.Net; using System.Net.Security; using System.Net.Sockets; -using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; diff --git a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs index c6926dd..6240f25 100644 --- a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs +++ b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs @@ -85,7 +85,7 @@ namespace MQTTnet.Implementations _webSocket = clientWebSocket; } - public async Task DisconnectAsync() + public async Task DisconnectAsync(CancellationToken cancellationToken) { if (_webSocket == null) { @@ -94,7 +94,7 @@ namespace MQTTnet.Implementations if (_webSocket.State == WebSocketState.Open || _webSocket.State == WebSocketState.Connecting) { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); } Dispose(); diff --git a/Source/MQTTnet/Internal/TaskExtensions.cs b/Source/MQTTnet/Internal/MqttTaskTimeout.cs similarity index 84% rename from Source/MQTTnet/Internal/TaskExtensions.cs rename to Source/MQTTnet/Internal/MqttTaskTimeout.cs index 1356d97..ba4ec7f 100644 --- a/Source/MQTTnet/Internal/TaskExtensions.cs +++ b/Source/MQTTnet/Internal/MqttTaskTimeout.cs @@ -5,9 +5,9 @@ using MQTTnet.Exceptions; namespace MQTTnet.Internal { - public static class TaskExtensions + public static class MqttTaskTimeout { - public static async Task TimeoutAfterAsync(Func action, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task WaitAsync(Func action, TimeSpan timeout, CancellationToken cancellationToken) { if (action == null) throw new ArgumentNullException(nameof(action)); @@ -31,7 +31,7 @@ namespace MQTTnet.Internal } } - public static async Task TimeoutAfterAsync(Func> action, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task WaitAsync(Func> action, TimeSpan timeout, CancellationToken cancellationToken) { if (action == null) throw new ArgumentNullException(nameof(action)); diff --git a/Source/MQTTnet/Internal/TestMqttChannel.cs b/Source/MQTTnet/Internal/TestMqttChannel.cs index 08920e3..03f23b1 100644 --- a/Source/MQTTnet/Internal/TestMqttChannel.cs +++ b/Source/MQTTnet/Internal/TestMqttChannel.cs @@ -21,7 +21,7 @@ namespace MQTTnet.Internal return Task.FromResult(0); } - public Task DisconnectAsync() + public Task DisconnectAsync(CancellationToken cancellationToken) { return Task.FromResult(0); } diff --git a/Source/MQTTnet/MqttApplicationMessageBuilder.cs b/Source/MQTTnet/MqttApplicationMessageBuilder.cs index 1df8645..ffd9031 100644 --- a/Source/MQTTnet/MqttApplicationMessageBuilder.cs +++ b/Source/MQTTnet/MqttApplicationMessageBuilder.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; @@ -38,12 +39,24 @@ namespace MQTTnet return this; } - _payload = payload.ToArray(); + _payload = payload as byte[]; + + if (_payload == null) + { + _payload = payload.ToArray(); + } + return this; } public MqttApplicationMessageBuilder WithPayload(Stream payload) { + if (payload == null) + { + _payload = null; + return this; + } + return WithPayload(payload, payload.Length - payload.Position); } diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs index c30a6f4..19df6d4 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs @@ -1,6 +1,7 @@ using System; using System.Threading; using System.Threading.Tasks; +using MQTTnet.Exceptions; using MQTTnet.Packets; namespace MQTTnet.PacketDispatcher @@ -21,7 +22,7 @@ namespace MQTTnet.PacketDispatcher { using (var timeoutToken = new CancellationTokenSource(timeout)) { - timeoutToken.Token.Register(() => _taskCompletionSource.TrySetCanceled()); + timeoutToken.Token.Register(() => _taskCompletionSource.TrySetException(new MqttCommunicationTimedOutException())); var packet = await _taskCompletionSource.Task.ConfigureAwait(false); return (TPacket)packet; diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs index b4a26c2..26b36ec 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs @@ -1,6 +1,6 @@ using System; using System.Collections.Concurrent; -using System.Threading; +using MQTTnet.Exceptions; using MQTTnet.Packets; namespace MQTTnet.PacketDispatcher @@ -9,8 +9,6 @@ namespace MQTTnet.PacketDispatcher { private readonly ConcurrentDictionary, IMqttPacketAwaiter> _packetAwaiters = new ConcurrentDictionary, IMqttPacketAwaiter>(); - private BlockingCollection _inboundPackagesQueue = new BlockingCollection(); - public void Dispatch(Exception exception) { foreach (var awaiter in _packetAwaiters) @@ -40,34 +38,14 @@ namespace MQTTnet.PacketDispatcher return; } - lock (_inboundPackagesQueue) - { - _inboundPackagesQueue.Add(packet); - } - } - - public MqttBasePacket Take(CancellationToken cancellationToken) - { - BlockingCollection inboundPackagesQueue; - lock (_inboundPackagesQueue) - { - inboundPackagesQueue = _inboundPackagesQueue; - } - - return inboundPackagesQueue.Take(cancellationToken); + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); } public void Reset() { foreach (var awaiter in _packetAwaiters) { - awaiter.Value.Cancel(); - } - - lock (_inboundPackagesQueue) - { - _inboundPackagesQueue?.Dispose(); - _inboundPackagesQueue = new BlockingCollection(); + awaiter.Value.Cancel(); } _packetAwaiters.Clear(); diff --git a/Source/MQTTnet/Server/IMqttClientSession.cs b/Source/MQTTnet/Server/IMqttClientSession.cs index f1e3010..d7db2b1 100644 --- a/Source/MQTTnet/Server/IMqttClientSession.cs +++ b/Source/MQTTnet/Server/IMqttClientSession.cs @@ -1,12 +1,11 @@ -using System; -using System.Threading.Tasks; +using System.Threading.Tasks; namespace MQTTnet.Server { - public interface IMqttClientSession : IDisposable + public interface IMqttClientSession { string ClientId { get; } - Task StopAsync(MqttClientDisconnectType disconnectType); + Task StopAsync(); } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/IMqttClientSessionStatus.cs b/Source/MQTTnet/Server/IMqttClientSessionStatus.cs deleted file mode 100644 index d46db69..0000000 --- a/Source/MQTTnet/Server/IMqttClientSessionStatus.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; -using System.Threading.Tasks; -using MQTTnet.Formatter; - -namespace MQTTnet.Server -{ - public interface IMqttClientSessionStatus - { - string ClientId { get; } - - string Endpoint { get; } - - bool IsConnected { get; } - - MqttProtocolVersion? ProtocolVersion { get; } - - TimeSpan LastPacketReceived { get; } - - TimeSpan LastNonKeepAlivePacketReceived { get; } - - long PendingApplicationMessagesCount { get; } - - long ReceivedApplicationMessagesCount { get; } - - long SentApplicationMessagesCount { get; } - - Task DisconnectAsync(); - - Task DeleteSessionAsync(); - - Task ClearPendingApplicationMessagesAsync(); - } -} diff --git a/Source/MQTTnet/Server/IMqttServer.cs b/Source/MQTTnet/Server/IMqttServer.cs index cccda2b..aceceee 100644 --- a/Source/MQTTnet/Server/IMqttServer.cs +++ b/Source/MQTTnet/Server/IMqttServer.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; +using MQTTnet.Server.Status; namespace MQTTnet.Server { @@ -16,13 +17,14 @@ namespace MQTTnet.Server IMqttServerOptions Options { get; } - Task> GetClientSessionsStatusAsync(); + Task> GetClientStatusAsync(); + Task> GetSessionStatusAsync(); - IList GetRetainedMessages(); + Task> GetRetainedMessagesAsync(); Task ClearRetainedMessagesAsync(); - Task SubscribeAsync(string clientId, IEnumerable topicFilters); - Task UnsubscribeAsync(string clientId, IEnumerable topicFilters); + Task SubscribeAsync(string clientId, ICollection topicFilters); + Task UnsubscribeAsync(string clientId, ICollection topicFilters); Task StartAsync(IMqttServerOptions options); Task StopAsync(); diff --git a/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs b/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs index bd38c83..68b5be6 100644 --- a/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs +++ b/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs @@ -4,6 +4,6 @@ namespace MQTTnet.Server { public interface IMqttServerConnectionValidator { - Task ValidateConnection(MqttConnectionValidatorContext context); + Task ValidateConnectionAsync(MqttConnectionValidatorContext context); } } diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs new file mode 100644 index 0000000..fa1d6a0 --- /dev/null +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -0,0 +1,568 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Adapter; +using MQTTnet.Client; +using MQTTnet.Diagnostics; +using MQTTnet.Exceptions; +using MQTTnet.Formatter; +using MQTTnet.PacketDispatcher; +using MQTTnet.Packets; +using MQTTnet.Protocol; +using MQTTnet.Server.Status; + +namespace MQTTnet.Server +{ + public class MqttClientSession + { + private readonly DateTime _createdTimestamp = DateTime.UtcNow; + + public MqttClientSession(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions) + { + ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + + SubscriptionsManager = new MqttClientSubscriptionsManager(clientId, eventDispatcher, serverOptions); + ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); + } + + public string ClientId { get; } + + public bool IsCleanSession { get; set; } = true; + + public MqttApplicationMessage WillMessage { get; set; } + + public MqttClientSubscriptionsManager SubscriptionsManager { get; } + + public MqttClientSessionApplicationMessagesQueue ApplicationMessagesQueue { get; } + + public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage) + { + var checkSubscriptionsResult = SubscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel); + if (!checkSubscriptionsResult.IsSubscribed) + { + return; + } + + ApplicationMessagesQueue.Enqueue(applicationMessage, senderClientId, checkSubscriptionsResult.QualityOfServiceLevel, isRetainedApplicationMessage); + } + + public async Task SubscribeAsync(ICollection topicFilters, MqttRetainedMessagesManager retainedMessagesManager) + { + await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); + var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); + foreach (var matchingRetainedMessage in matchingRetainedMessages) + { + EnqueueApplicationMessage(matchingRetainedMessage, null, true); + } + } + + public Task UnsubscribeAsync(IEnumerable topicFilters) + { + return SubscriptionsManager.UnsubscribeAsync(topicFilters); + } + + public void FillStatus(MqttSessionStatus status) + { + status.ClientId = ClientId; + status.CreatedTimestamp = _createdTimestamp; + status.PendingApplicationMessagesCount = ApplicationMessagesQueue.Count; + } + } + + public class MqttClientConnection : IMqttClientSession, IDisposable + { + private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); + private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); + private readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); + + private readonly MqttRetainedMessagesManager _retainedMessagesManager; + private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; + private readonly MqttClientSessionsManager _sessionsManager; + + private readonly IMqttNetChildLogger _logger; + private readonly IMqttServerOptions _serverOptions; + + private Task _packageReceiverTask; + private readonly IMqttChannelAdapter _channelAdapter; + private readonly IMqttDataConverter _dataConverter; + private readonly string _endpoint; + private readonly MqttConnectPacket _connectPacket; + + private DateTime _lastPacketReceivedTimestamp; + private long _receivedPacketsCount; + private long _sentPacketsCount; + private long _receivedApplicationMessagesCount; + private long _sentApplicationMessagesCount; + + public MqttClientConnection( + MqttConnectPacket connectPacket, + IMqttChannelAdapter channelAdapter, + MqttClientSession session, + IMqttServerOptions serverOptions, + MqttClientSessionsManager sessionsManager, + MqttRetainedMessagesManager retainedMessagesManager, + IMqttNetChildLogger logger) + { + Session = session ?? throw new ArgumentNullException(nameof(session)); + _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); + _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); + _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); + + _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); + _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; + _endpoint = _channelAdapter.Endpoint; + _connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); + + if (logger == null) throw new ArgumentNullException(nameof(logger)); + _logger = logger.CreateChildLogger(nameof(MqttClientConnection)); + + _keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger); + + _lastPacketReceivedTimestamp = DateTime.UtcNow; + } + + public string ClientId => _connectPacket.ClientId; + + public MqttClientSession Session { get; } + + public async Task StopAsync() + { + StopInternal(); + + var task = _packageReceiverTask; + if (task != null && !task.IsCompleted) + { + await task.ConfigureAwait(false); + } + } + + public void FillStatus(MqttClientStatus status) + { + status.ClientId = ClientId; + status.Endpoint = _endpoint; + status.ProtocolVersion = _channelAdapter.PacketFormatterAdapter.ProtocolVersion.Value; + + status.ReceivedApplicationMessagesCount = Interlocked.Read(ref _receivedApplicationMessagesCount); + status.SentApplicationMessagesCount = Interlocked.Read(ref _sentApplicationMessagesCount); + + status.ReceivedPacketsCount = Interlocked.Read(ref _receivedPacketsCount); + status.SentPacketsCount = Interlocked.Read(ref _sentPacketsCount); + + status.LastPacketReceivedTimestamp = _lastPacketReceivedTimestamp; + + //status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; + } + + //public void ClearPendingApplicationMessages() + //{ + // Session.ApplicationMessagesQueue.Clear(); + + // //_applicationMessagesQueue.Clear(); + //} + + public void Dispose() + { + _cancellationToken.Dispose(); + } + + public Task RunAsync() + { + _packageReceiverTask = RunInternalAsync(); + return _packageReceiverTask; + } + + public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage) + { + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + Session.EnqueueApplicationMessage(applicationMessage, senderClientId, isRetainedApplicationMessage); + + _logger.Verbose("Enqueued application message (ClientId: {0}).", ClientId); + } + + private async Task RunInternalAsync() + { + var disconnectType = MqttClientDisconnectType.NotClean; + try + { + _logger.Info("Client '{0}': Session started.", ClientId); + //_eventDispatcher.OnClientConnected(ClientId); + + _channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted; + _channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted; + + Session.WillMessage = _connectPacket.WillMessage; + + Task.Run(() => SendPendingPacketsAsync(), _cancellationToken.Token); + + // TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor. + _keepAliveMonitor.Start(_connectPacket.KeepAlivePeriod, _cancellationToken.Token); + + await SendAsync( + new MqttConnAckPacket + { + ReturnCode = MqttConnectReturnCode.ConnectionAccepted, + ReasonCode = MqttConnectReasonCode.Success, + IsSessionPresent = Session.IsCleanSession + }).ConfigureAwait(false); + + Session.IsCleanSession = false; + + while (!_cancellationToken.IsCancellationRequested) + { + var packet = await _channelAdapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationToken.Token).ConfigureAwait(false); + if (packet == null) + { + // The client has closed the connection gracefully. + break; + } + + Interlocked.Increment(ref _sentPacketsCount); + _lastPacketReceivedTimestamp = DateTime.UtcNow; + + _keepAliveMonitor.PacketReceived(); + + if (packet is MqttPublishPacket publishPacket) + { + await HandleIncomingPublishPacketAsync(publishPacket).ConfigureAwait(false); + continue; + } + + if (packet is MqttPubRelPacket pubRelPacket) + { + var pubCompPacket = new MqttPubCompPacket + { + PacketIdentifier = pubRelPacket.PacketIdentifier, + ReasonCode = MqttPubCompReasonCode.Success + }; + + await SendAsync(pubCompPacket).ConfigureAwait(false); + continue; + } + + if (packet is MqttSubscribePacket subscribePacket) + { + await HandleIncomingSubscribePacketAsync(subscribePacket).ConfigureAwait(false); + continue; + } + + if (packet is MqttUnsubscribePacket unsubscribePacket) + { + await HandleIncomingUnsubscribePacketAsync(unsubscribePacket).ConfigureAwait(false); + continue; + } + + if (packet is MqttPingReqPacket) + { + await SendAsync(new MqttPingRespPacket()).ConfigureAwait(false); + continue; + } + + if (packet is MqttDisconnectPacket) + { + Session.WillMessage = null; + disconnectType = MqttClientDisconnectType.Clean; + + StopInternal(); + break; + } + + _packetDispatcher.Dispatch(packet); + } + } + catch (OperationCanceledException) + { + } + catch (Exception exception) + { + if (exception is MqttCommunicationException) + { + _logger.Warning(exception, "Client '{0}': Communication exception while receiving client packets.", ClientId); + } + else + { + _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); + } + + StopInternal(); + } + finally + { + if (Session.WillMessage != null) + { + _sessionsManager.DispatchApplicationMessage(Session.WillMessage, this); + Session.WillMessage = null; + } + + _packetDispatcher.Reset(); + + _channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; + _channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; + + _logger.Info("Client '{0}': Session stopped.", ClientId); + //_eventDispatcher.OnClientDisconnected(ClientId); + + _packageReceiverTask = null; + } + + return disconnectType; + } + + private void StopInternal() + { + _cancellationToken.Cancel(false); + } + + private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) + { + var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); + foreach (var applicationMessage in retainedMessages) + { + EnqueueApplicationMessage(applicationMessage, ClientId, true); + } + } + + private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) + { + // TODO: Let the channel adapter create the packet. + var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); + + await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false); + + if (subscribeResult.CloseConnection) + { + StopInternal(); + return; + } + + await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); + } + + private async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) + { + // TODO: Let the channel adapter create the packet. + var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); + await SendAsync(unsubscribeResult).ConfigureAwait(false); + } + + private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) + { + Interlocked.Increment(ref _sentApplicationMessagesCount); + + switch (publishPacket.QualityOfServiceLevel) + { + case MqttQualityOfServiceLevel.AtMostOnce: + { + return HandleIncomingPublishPacketWithQoS0Async(publishPacket); + } + case MqttQualityOfServiceLevel.AtLeastOnce: + { + return HandleIncomingPublishPacketWithQoS1Async(publishPacket); + } + case MqttQualityOfServiceLevel.ExactlyOnce: + { + return HandleIncomingPublishPacketWithQoS2Async(publishPacket); + } + default: + { + throw new MqttCommunicationException("Received a not supported QoS level."); + } + } + } + + private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) + { + var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); + + _sessionsManager.DispatchApplicationMessage(applicationMessage, this); + + return Task.FromResult(0); + } + + private Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) + { + var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); + _sessionsManager.DispatchApplicationMessage(applicationMessage, this); + + // TODO: Create ACK packet. + + var response = new MqttPubAckPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubAckReasonCode.Success + }; + + return SendAsync(response); + } + + private async Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket) + { + var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); + _sessionsManager.DispatchApplicationMessage(applicationMessage, this); + + var pubRecPacket = new MqttPubRecPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubRecReasonCode.Success + }; + + await SendAsync(pubRecPacket).ConfigureAwait(false); + + ////Task.Run(async () => + ////{ + //// using (var pubRelPacketAwaiter = _packetDispatcher.AddPacketAwaiter(publishPacket.PacketIdentifier)) + //// { + //// await pubRelPacketAwaiter.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + + //// var pubCompPacket = new MqttPubCompPacket + //// { + //// PacketIdentifier = publishPacket.PacketIdentifier, + //// ReasonCode = MqttPubCompReasonCode.Success + //// }; + + //// await SendAsync(pubCompPacket).ConfigureAwait(false); + //// } + ////}); + + await Task.FromResult(0); + } + + private async Task SendPendingPacketsAsync() + { + MqttPendingApplicationMessage enqueuedApplicationMessage = null; + MqttPublishPacket publishPacket = null; + + try + { + while (!_cancellationToken.IsCancellationRequested) + { + enqueuedApplicationMessage = await Session.ApplicationMessagesQueue.TakeAsync(_cancellationToken.Token).ConfigureAwait(false); + if (enqueuedApplicationMessage == null) + { + return; + } + + if (_cancellationToken.IsCancellationRequested) + { + return; + } + + publishPacket = _dataConverter.CreatePublishPacket(enqueuedApplicationMessage.ApplicationMessage); + publishPacket.QualityOfServiceLevel = enqueuedApplicationMessage.QualityOfServiceLevel; + + // Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9]. + publishPacket.Retain = enqueuedApplicationMessage.IsRetainedMessage; + + if (publishPacket.QualityOfServiceLevel > 0) + { + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); + } + + if (_serverOptions.ClientMessageQueueInterceptor != null) + { + var context = new MqttClientMessageQueueInterceptorContext( + enqueuedApplicationMessage.SenderClientId, + ClientId, + enqueuedApplicationMessage.ApplicationMessage); + + if (_serverOptions.ClientMessageQueueInterceptor != null) + { + await _serverOptions.ClientMessageQueueInterceptor.InterceptClientMessageQueueEnqueueAsync(context).ConfigureAwait(false); + } + + if (!context.AcceptEnqueue || context.ApplicationMessage == null) + { + return; + } + + publishPacket.Topic = context.ApplicationMessage.Topic; + publishPacket.Payload = context.ApplicationMessage.Payload; + publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; + } + + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) + { + await SendAsync(publishPacket).ConfigureAwait(false); + } + else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + { + var awaiter = _packetDispatcher.AddPacketAwaiter(publishPacket.PacketIdentifier); + await SendAsync(publishPacket).ConfigureAwait(false); + await awaiter.WaitOneAsync(_serverOptions.DefaultCommunicationTimeout).ConfigureAwait(false); + } + else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) + { + using (var awaiter1 = _packetDispatcher.AddPacketAwaiter(publishPacket.PacketIdentifier)) + using (var awaiter2 = _packetDispatcher.AddPacketAwaiter(publishPacket.PacketIdentifier)) + { + await SendAsync(publishPacket).ConfigureAwait(false); + await awaiter1.WaitOneAsync(_serverOptions.DefaultCommunicationTimeout).ConfigureAwait(false); + + await SendAsync(new MqttPubRelPacket { PacketIdentifier = publishPacket.PacketIdentifier }).ConfigureAwait(false); + await awaiter2.WaitOneAsync(_serverOptions.DefaultCommunicationTimeout).ConfigureAwait(false); + } + } + + _logger.Verbose("Enqueued application message sent (ClientId: {0}).", ClientId); + + // TODO: + //Interlocked.Increment(ref _sentPacketsCount); + } + } + catch (Exception exception) + { + if (exception is MqttCommunicationTimedOutException) + { + _logger.Warning(exception, "Sending publish packet failed: Timeout (ClientId: {0}).", ClientId); + } + else if (exception is MqttCommunicationException) + { + _logger.Warning(exception, "Sending publish packet failed: Communication exception (ClientId: {0}).", ClientId); + } + else if (exception is OperationCanceledException && _cancellationToken.Token.IsCancellationRequested) + { + // The cancellation was triggered externally. + } + else + { + _logger.Error(exception, "Sending publish packet failed (ClientId: {0}).", ClientId); + } + + if (publishPacket?.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) + { + enqueuedApplicationMessage.IsDuplicate = true; + + Session.ApplicationMessagesQueue.Enqueue(enqueuedApplicationMessage); + } + + if (!_cancellationToken.Token.IsCancellationRequested) + { + await StopAsync().ConfigureAwait(false); + } + } + } + + private async Task SendAsync(MqttBasePacket packet) + { + await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false); + + Interlocked.Increment(ref _receivedPacketsCount); + + if (packet is MqttPublishPacket) + { + Interlocked.Increment(ref _receivedApplicationMessagesCount); + } + } + + private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) + { + _keepAliveMonitor?.Resume(); + } + + private void OnAdapterReadingPacketStarted(object sender, EventArgs e) + { + _keepAliveMonitor?.Pause(); + } + } +} diff --git a/Source/MQTTnet/Server/MqttClientDisconnectType.cs b/Source/MQTTnet/Server/MqttClientDisconnectType.cs index 19e9da8..c4d6f59 100644 --- a/Source/MQTTnet/Server/MqttClientDisconnectType.cs +++ b/Source/MQTTnet/Server/MqttClientDisconnectType.cs @@ -3,6 +3,7 @@ public enum MqttClientDisconnectType { Clean, - NotClean + NotClean, + Takeover } } diff --git a/Source/MQTTnet/Server/MqttClientDisconnectedEventArgs.cs b/Source/MQTTnet/Server/MqttClientDisconnectedEventArgs.cs index 7e88548..225d025 100644 --- a/Source/MQTTnet/Server/MqttClientDisconnectedEventArgs.cs +++ b/Source/MQTTnet/Server/MqttClientDisconnectedEventArgs.cs @@ -4,14 +4,14 @@ namespace MQTTnet.Server { public class MqttClientDisconnectedEventArgs : EventArgs { - public MqttClientDisconnectedEventArgs(string clientId, bool wasCleanDisconnect) + public MqttClientDisconnectedEventArgs(string clientId, MqttClientDisconnectType disconnectType) { ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); - WasCleanDisconnect = wasCleanDisconnect; + DisconnectType = disconnectType; } public string ClientId { get; } - public bool WasCleanDisconnect { get; } + public MqttClientDisconnectType DisconnectType { get; } } } diff --git a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs index e8e7c15..3f89d19 100644 --- a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs +++ b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs @@ -10,7 +10,6 @@ namespace MQTTnet.Server public class MqttClientKeepAliveMonitor { private readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch(); - private readonly Stopwatch _lastNonKeepAlivePacketReceivedTracker = new Stopwatch(); private readonly IMqttClientSession _clientSession; private readonly IMqttNetChildLogger _logger; @@ -26,10 +25,6 @@ namespace MQTTnet.Server _logger = logger.CreateChildLogger(nameof(MqttClientKeepAliveMonitor)); } - public TimeSpan LastPacketReceived => _lastPacketReceivedTracker.Elapsed; - - public TimeSpan LastNonKeepAlivePacketReceived => _lastNonKeepAlivePacketReceivedTracker.Elapsed; - public void Start(int keepAlivePeriod, CancellationToken cancellationToken) { if (keepAlivePeriod == 0) @@ -50,20 +45,9 @@ namespace MQTTnet.Server _isPaused = false; } - public void Reset() + public void PacketReceived() { _lastPacketReceivedTracker.Restart(); - _lastNonKeepAlivePacketReceivedTracker.Restart(); - } - - public void PacketReceived(MqttBasePacket packet) - { - _lastPacketReceivedTracker.Restart(); - - if (!(packet is MqttPingReqPacket)) - { - _lastNonKeepAlivePacketReceivedTracker.Restart(); - } } private async Task RunAsync(int keepAlivePeriod, CancellationToken cancellationToken) @@ -71,7 +55,6 @@ namespace MQTTnet.Server try { _lastPacketReceivedTracker.Restart(); - _lastNonKeepAlivePacketReceivedTracker.Restart(); while (!cancellationToken.IsCancellationRequested) { @@ -81,7 +64,7 @@ namespace MQTTnet.Server if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds >= keepAlivePeriod * 1.5D) { _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId); - await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); + await _clientSession.StopAsync().ConfigureAwait(false); return; } diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs deleted file mode 100644 index ad6fdfa..0000000 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ /dev/null @@ -1,481 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Adapter; -using MQTTnet.Client; -using MQTTnet.Diagnostics; -using MQTTnet.Exceptions; -using MQTTnet.MessageStream; -using MQTTnet.PacketDispatcher; -using MQTTnet.Packets; -using MQTTnet.Protocol; - -namespace MQTTnet.Server -{ - public class MqttClientSession : IMqttClientSession - { - private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); - private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); - - private readonly MqttRetainedMessagesManager _retainedMessagesManager; - private readonly MqttServerEventDispatcher _eventDispatcher; - private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; - private readonly MqttClientSessionPendingMessagesQueue _pendingMessagesQueue; - private readonly MqttClientSubscriptionsManager _subscriptionsManager; - private readonly MqttClientSessionsManager _sessionsManager; - - private readonly IMqttNetChildLogger _logger; - private readonly IMqttServerOptions _options; - - private CancellationTokenSource _cancellationTokenSource; - private MqttApplicationMessage _willMessage; - private bool _wasCleanDisconnect; - private Task _workerTask; - private IMqttChannelAdapter _channelAdapter; - - private long _receivedMessagesCount; - private bool _isCleanSession = true; - - public MqttClientSession( - string clientId, - IMqttServerOptions options, - MqttClientSessionsManager sessionsManager, - MqttRetainedMessagesManager retainedMessagesManager, - MqttServerEventDispatcher eventDispatcher, - IMqttNetChildLogger logger) - { - if (logger == null) throw new ArgumentNullException(nameof(logger)); - - _options = options ?? throw new ArgumentNullException(nameof(options)); - _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); - _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); - _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); - - ClientId = clientId; - - _logger = logger.CreateChildLogger(nameof(MqttClientSession)); - - _keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger); - _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, eventDispatcher); - _pendingMessagesQueue = new MqttClientSessionPendingMessagesQueue(_options, this, _packetDispatcher, _logger); - } - - public string ClientId { get; } - - public void FillStatus(MqttClientSessionStatus status) - { - status.ClientId = ClientId; - status.IsConnected = _cancellationTokenSource != null; - status.Endpoint = _channelAdapter?.Endpoint; - status.ProtocolVersion = _channelAdapter?.PacketFormatterAdapter?.ProtocolVersion; - status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count; - status.ReceivedApplicationMessagesCount = _pendingMessagesQueue.SentMessagesCount; - status.SentApplicationMessagesCount = Interlocked.Read(ref _receivedMessagesCount); - status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived; - status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; - } - - public async Task StopAsync(MqttClientDisconnectType type) - { - StopInternal(type); - - var task = _workerTask; - if (task != null && !task.IsCompleted) - { - await task.ConfigureAwait(false); - } - } - - public async Task SubscribeAsync(IEnumerable topicFilters) - { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - var topicFiltersCollection = topicFilters.ToList(); - - var packet = new MqttSubscribePacket(); - packet.TopicFilters.AddRange(topicFiltersCollection); - - await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false); - await EnqueueSubscribedRetainedMessagesAsync(topicFiltersCollection).ConfigureAwait(false); - } - - public Task UnsubscribeAsync(IEnumerable topicFilters) - { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - var packet = new MqttUnsubscribePacket(); - packet.TopicFilters.AddRange(topicFilters); - - _subscriptionsManager.Unsubscribe(packet); - - return Task.FromResult(0); - } - - public void ClearPendingApplicationMessages() - { - _pendingMessagesQueue.Clear(); - } - - public void Dispose() - { - _pendingMessagesQueue?.Dispose(); - - _cancellationTokenSource?.Cancel(); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; - } - - public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) - { - if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); - if (channelAdapter == null) throw new ArgumentNullException(nameof(channelAdapter)); - - _workerTask = RunInternalAsync(connectPacket, channelAdapter); - return _workerTask; - } - - public async Task EnqueueApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage) - { - if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - - var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel); - if (!checkSubscriptionsResult.IsSubscribed) - { - return; - } - - var publishPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage); - publishPacket.QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel; - - // Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9]. - publishPacket.Retain = isRetainedApplicationMessage; - - if (publishPacket.QualityOfServiceLevel > 0) - { - publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - } - - if (_options.ClientMessageQueueInterceptor != null) - { - var context = new MqttClientMessageQueueInterceptorContext( - senderClientSession?.ClientId, - ClientId, - applicationMessage); - - if (_options.ClientMessageQueueInterceptor != null) - { - await _options.ClientMessageQueueInterceptor.InterceptClientMessageQueueEnqueueAsync(context).ConfigureAwait(false); - } - - if (!context.AcceptEnqueue || context.ApplicationMessage == null) - { - return; - } - - publishPacket.Topic = context.ApplicationMessage.Topic; - publishPacket.Payload = context.ApplicationMessage.Payload; - publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; - } - - _pendingMessagesQueue.Enqueue(publishPacket); - } - - - - private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) - { - if (channelAdapter == null) throw new ArgumentNullException(nameof(channelAdapter)); - - try - { - _logger.Info("Client '{0}': Connected.", ClientId); - _eventDispatcher.OnClientConnected(ClientId); - - _channelAdapter = channelAdapter; - - _channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted; - _channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted; - - var cancellationTokenSource = new CancellationTokenSource(); - _cancellationTokenSource = cancellationTokenSource; - - _wasCleanDisconnect = false; - _willMessage = connectPacket.WillMessage; - - _pendingMessagesQueue.Start(channelAdapter, cancellationTokenSource.Token); - _keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, cancellationTokenSource.Token); - - await channelAdapter.SendPacketAsync( - new MqttConnAckPacket - { - ReturnCode = MqttConnectReturnCode.ConnectionAccepted, - ReasonCode = MqttConnectReasonCode.Success, - IsSessionPresent = _isCleanSession - }, - cancellationTokenSource.Token).ConfigureAwait(false); - - _isCleanSession = false; - - Task.Run(async () => - { - while (!cancellationTokenSource.IsCancellationRequested) - { - var packet = _packetDispatcher.Take(cancellationTokenSource.Token); - await ProcessReceivedPacketAsync(packet, cancellationTokenSource.Token).ConfigureAwait(false); - } - }, cancellationTokenSource.Token); - - Task.Run(async () => - { - while (!cancellationTokenSource.IsCancellationRequested) - { - try - { - var packet = await _outboundMessageStream.TakeAsync(cancellationTokenSource.Token); - await channelAdapter.SendPacketAsync(packet, cancellationTokenSource.Token); - } - catch (Exception e) - { - _logger.Error(e, "sdfsdf"); - await StopAsync(MqttClientDisconnectType.NotClean); - - } - - } - },cancellationTokenSource.Token); - - while (!cancellationTokenSource.IsCancellationRequested) - { - var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, cancellationTokenSource.Token).ConfigureAwait(false); - if (packet != null) - { - _keepAliveMonitor.PacketReceived(packet); - _packetDispatcher.Dispatch(packet); - } - } - } - catch (OperationCanceledException) - { - } - catch (Exception exception) - { - if (exception is MqttCommunicationException) - { - _logger.Warning(exception, "Client '{0}': Communication exception while receiving client packets.", ClientId); - } - else - { - _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); - } - - StopInternal(MqttClientDisconnectType.NotClean); - } - finally - { - if (_willMessage != null && !_wasCleanDisconnect) - { - _sessionsManager.EnqueueApplicationMessage(this, _willMessage); - } - - _willMessage = null; - - _channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; - _channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; - _channelAdapter = null; - - _logger.Info("Client '{0}': Session stopped.", ClientId); - _eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect); - - _workerTask = null; - } - } - - private void StopInternal(MqttClientDisconnectType type) - { - var cts = _cancellationTokenSource; - if (cts == null || cts.IsCancellationRequested) - { - return; - } - - _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; - _cancellationTokenSource?.Cancel(false); - _packetDispatcher.Reset(); - } - - private readonly MqttMessageStream _outboundMessageStream = new MqttMessageStream(); - - private Task ProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) - { - if (packet is MqttPublishPacket publishPacket) - { - return HandleIncomingPublishPacketAsync(publishPacket, cancellationToken); - } - - if (packet is MqttPingReqPacket) - { - //return channelAdapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); - _outboundMessageStream.Enqueue(new MqttPingRespPacket()); - return Task.FromResult(0); - } - - if (packet is MqttSubscribePacket subscribePacket) - { - return HandleIncomingSubscribePacketAsync(subscribePacket, cancellationToken); - } - - if (packet is MqttUnsubscribePacket unsubscribePacket) - { - return HandleIncomingUnsubscribePacketAsync(unsubscribePacket, cancellationToken); - } - - if (packet is MqttDisconnectPacket) - { - StopInternal(MqttClientDisconnectType.Clean); - return Task.FromResult(0); - } - - //if (packet is MqttAuthPacket || - // packet is MqttSubAckPacket || - // packet is MqttUnsubAckPacket || - // packet is MqttPubAckPacket || - // packet is MqttPubCompPacket || - // packet is MqttPubRecPacket || - // packet is MqttPubRelPacket) - //{ - // _packetDispatcher.TryDispatch(packet); - // return Task.FromResult(0); - //} - - _logger.Warning(null, "Client '{0}': Received invalid packet ({1}). Closing connection.", ClientId, packet); - - StopInternal(MqttClientDisconnectType.NotClean); - return Task.FromResult(0); - } - - private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) - { - var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); - foreach (var applicationMessage in retainedMessages) - { - await EnqueueApplicationMessageAsync(null, applicationMessage, true).ConfigureAwait(false); - } - } - - private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) - { - var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); - - _outboundMessageStream.Enqueue(subscribeResult.ResponsePacket); - //await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); - - // TODO: Add "WaitForDelivery". - - if (subscribeResult.CloseConnection) - { - StopInternal(MqttClientDisconnectType.NotClean); - return; - } - - await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); - } - - private Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) - { - var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); - - _outboundMessageStream.Enqueue(unsubscribeResult); - - //return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); - return Task.FromResult(0); - } - - private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) - { - Interlocked.Increment(ref _receivedMessagesCount); - - switch (publishPacket.QualityOfServiceLevel) - { - case MqttQualityOfServiceLevel.AtMostOnce: - { - return HandleIncomingPublishPacketWithQoS0Async(publishPacket); - } - case MqttQualityOfServiceLevel.AtLeastOnce: - { - return HandleIncomingPublishPacketWithQoS1Async(publishPacket, cancellationToken); - } - case MqttQualityOfServiceLevel.ExactlyOnce: - { - return HandleIncomingPublishPacketWithQoS2Async(publishPacket, cancellationToken); - } - default: - { - throw new MqttCommunicationException("Received a not supported QoS level."); - } - } - } - - private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) - { - _sessionsManager.EnqueueApplicationMessage( - this, - _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); - - return Task.FromResult(0); - } - - private Task HandleIncomingPublishPacketWithQoS1Async( - MqttPublishPacket publishPacket, - CancellationToken cancellationToken) - { - _sessionsManager.EnqueueApplicationMessage( - this, - _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); - - var response = new MqttPubAckPacket - { - PacketIdentifier = publishPacket.PacketIdentifier, - ReasonCode = MqttPubAckReasonCode.Success - }; - - _outboundMessageStream.Enqueue(response); - - //return adapter.SendPacketAsync(response, cancellationToken); - return Task.FromResult(0); - } - - private async Task HandleIncomingPublishPacketWithQoS2Async( - MqttPublishPacket publishPacket, - CancellationToken cancellationToken) - { - // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) - _sessionsManager.EnqueueApplicationMessage(this, _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); - - using (var pubRelPacketAwaiter = _packetDispatcher.AddPacketAwaiter(publishPacket.PacketIdentifier)) - { - var pubRecPacket = new MqttPubRecPacket - { - PacketIdentifier = publishPacket.PacketIdentifier, - ReasonCode = MqttPubRecReasonCode.Success - }; - - //await adapter.SendPacketAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); - _outboundMessageStream.Enqueue(pubRecPacket); - - await pubRelPacketAwaiter.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); - } - } - - private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) - { - _keepAliveMonitor?.Resume(); - } - - private void OnAdapterReadingPacketStarted(object sender, EventArgs e) - { - _keepAliveMonitor?.Pause(); - } - } -} diff --git a/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs new file mode 100644 index 0000000..7fbd9e2 --- /dev/null +++ b/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs @@ -0,0 +1,117 @@ +using MQTTnet.Internal; +using MQTTnet.Protocol; +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public class MqttPendingApplicationMessage + { + public MqttApplicationMessage ApplicationMessage { get; set; } + + public string SenderClientId { get; set; } + + public bool IsRetainedMessage { get; set; } + + public MqttQualityOfServiceLevel QualityOfServiceLevel { get; set; } + + public bool IsDuplicate { get; set; } + } + + public class MqttClientSessionApplicationMessagesQueue : IDisposable + { + private readonly Queue _messageQueue = new Queue(); + private readonly AsyncAutoResetEvent _messageQueueLock = new AsyncAutoResetEvent(); + + private readonly IMqttServerOptions _options; + + public MqttClientSessionApplicationMessagesQueue(IMqttServerOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + + } + + public int Count + { + get + { + lock (_messageQueue) + { + return _messageQueue.Count; + } + } + } + + public void Enqueue(MqttApplicationMessage applicationMessage, string senderClientId, MqttQualityOfServiceLevel qualityOfServiceLevel, bool isRetainedMessage) + { + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + Enqueue(new MqttPendingApplicationMessage + { + ApplicationMessage = applicationMessage, + SenderClientId = senderClientId, + QualityOfServiceLevel = qualityOfServiceLevel, + IsRetainedMessage = isRetainedMessage + }); + } + + public void Clear() + { + lock (_messageQueue) + { + _messageQueue.Clear(); + } + } + + public void Dispose() + { + } + + public async Task TakeAsync(CancellationToken cancellationToken) + { + // TODO: Create a blocking queue from this. + + while (!cancellationToken.IsCancellationRequested) + { + lock (_messageQueue) + { + if (_messageQueue.Count > 0) + { + return _messageQueue.Dequeue(); + } + } + + await _messageQueueLock.WaitOneAsync(cancellationToken).ConfigureAwait(false); + } + + return null; + } + + public void Enqueue(MqttPendingApplicationMessage enqueuedApplicationMessage) + { + if (enqueuedApplicationMessage == null) throw new ArgumentNullException(nameof(enqueuedApplicationMessage)); + + lock (_messageQueue) + { + if (_messageQueue.Count >= _options.MaxPendingMessagesPerClient) + { + if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage) + { + return; + } + + if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage) + { + _messageQueue.Dequeue(); + } + } + + _messageQueue.Enqueue(enqueuedApplicationMessage); + } + + _messageQueueLock.Set(); + } + } +} diff --git a/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs deleted file mode 100644 index 1fd19e8..0000000 --- a/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs +++ /dev/null @@ -1,212 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Adapter; -using MQTTnet.Diagnostics; -using MQTTnet.Exceptions; -using MQTTnet.Internal; -using MQTTnet.PacketDispatcher; -using MQTTnet.Packets; -using MQTTnet.Protocol; - -namespace MQTTnet.Server -{ - public class MqttClientSessionPendingMessagesQueue : IDisposable - { - private readonly Queue _queue = new Queue(); - private readonly AsyncAutoResetEvent _queueLock = new AsyncAutoResetEvent(); - - private readonly IMqttServerOptions _options; - private readonly MqttClientSession _clientSession; - private readonly MqttPacketDispatcher _packetDispatcher; - private readonly IMqttNetChildLogger _logger; - - private long _sentPacketsCount; - - public MqttClientSessionPendingMessagesQueue( - IMqttServerOptions options, - MqttClientSession clientSession, - MqttPacketDispatcher packetDispatcher, - IMqttNetChildLogger logger) - { - if (logger == null) throw new ArgumentNullException(nameof(logger)); - _options = options ?? throw new ArgumentNullException(nameof(options)); - _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); - _packetDispatcher = packetDispatcher ?? throw new ArgumentNullException(nameof(packetDispatcher)); - - _logger = logger.CreateChildLogger(nameof(MqttClientSessionPendingMessagesQueue)); - } - - public int Count - { - get - { - lock (_queue) - { - return _queue.Count; - } - } - } - - public long SentMessagesCount => Interlocked.Read(ref _sentPacketsCount); - - public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken) - { - if (adapter == null) throw new ArgumentNullException(nameof(adapter)); - - if (cancellationToken.IsCancellationRequested) - { - return; - } - - Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken); - } - - public void Enqueue(MqttPublishPacket packet) - { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - - lock (_queue) - { - if (_queue.Count >= _options.MaxPendingMessagesPerClient) - { - if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage) - { - return; - } - - if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage) - { - _queue.Dequeue(); - } - } - - _queue.Enqueue(packet); - } - - _queueLock.Set(); - - _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); - } - - public void Clear() - { - lock (_queue) - { - _queue.Clear(); - } - } - - public void Dispose() - { - } - - private async Task SendQueuedPacketsAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken) - { - try - { - while (!cancellationToken.IsCancellationRequested) - { - await TrySendNextQueuedPacketAsync(adapter, cancellationToken).ConfigureAwait(false); - } - } - catch (OperationCanceledException) - { - } - catch (Exception exception) - { - _logger.Error(exception, "Unhandled exception while sending enqueued packet (ClientId: {0}).", _clientSession.ClientId); - } - } - - private async Task TrySendNextQueuedPacketAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken) - { - MqttPublishPacket packet = null; - try - { - if (cancellationToken.IsCancellationRequested) - { - return; - } - - lock (_queue) - { - if (_queue.Count > 0) - { - packet = _queue.Dequeue(); - } - } - - if (packet == null) - { - await _queueLock.WaitOneAsync(cancellationToken).ConfigureAwait(false); - return; - } - - if (packet.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) - { - await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); - } - else if (packet.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) - { - var awaiter = _packetDispatcher.AddPacketAwaiter(packet.PacketIdentifier); - await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); - await awaiter.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); - } - else if (packet.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) - { - var awaiter1 = _packetDispatcher.AddPacketAwaiter(packet.PacketIdentifier); - var awaiter2 = _packetDispatcher.AddPacketAwaiter(packet.PacketIdentifier); - try - { - await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); - await awaiter1.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); - - await adapter.SendPacketAsync(new MqttPubRelPacket { PacketIdentifier = packet.PacketIdentifier }, cancellationToken).ConfigureAwait(false); - await awaiter2.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); - } - finally - { - _packetDispatcher.RemovePacketAwaiter(packet.PacketIdentifier); - _packetDispatcher.RemovePacketAwaiter(packet.PacketIdentifier); - } - } - - _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); - - Interlocked.Increment(ref _sentPacketsCount); - } - catch (Exception exception) - { - if (exception is MqttCommunicationTimedOutException) - { - _logger.Warning(exception, "Sending publish packet failed: Timeout (ClientId: {0}).", _clientSession.ClientId); - } - else if (exception is MqttCommunicationException) - { - _logger.Warning(exception, "Sending publish packet failed: Communication exception (ClientId: {0}).", _clientSession.ClientId); - } - else if (exception is OperationCanceledException && cancellationToken.IsCancellationRequested) - { - } - else - { - _logger.Error(exception, "Sending publish packet failed (ClientId: {0}).", _clientSession.ClientId); - } - - if (packet?.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) - { - packet.Dup = true; - - Enqueue(packet); - } - - if (!cancellationToken.IsCancellationRequested) - { - await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); - } - } - } - } -} diff --git a/Source/MQTTnet/Server/MqttClientSessionStatus.cs b/Source/MQTTnet/Server/MqttClientSessionStatus.cs deleted file mode 100644 index 007a19b..0000000 --- a/Source/MQTTnet/Server/MqttClientSessionStatus.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Threading.Tasks; -using MQTTnet.Formatter; - -namespace MQTTnet.Server -{ - public class MqttClientSessionStatus : IMqttClientSessionStatus - { - private readonly MqttClientSessionsManager _sessionsManager; - private readonly MqttClientSession _session; - - public MqttClientSessionStatus(MqttClientSessionsManager sessionsManager, MqttClientSession session) - { - _sessionsManager = sessionsManager; - _session = session; - } - - public string ClientId { get; set; } - public string Endpoint { get; set; } - public bool IsConnected { get; set; } - public MqttProtocolVersion? ProtocolVersion { get; set; } - public TimeSpan LastPacketReceived { get; set; } - public TimeSpan LastNonKeepAlivePacketReceived { get; set; } - public long PendingApplicationMessagesCount { get; set; } - public long ReceivedApplicationMessagesCount { get; set; } - public long SentApplicationMessagesCount { get; set; } - - public Task DisconnectAsync() - { - return _session.StopAsync(MqttClientDisconnectType.NotClean); - } - - public async Task DeleteSessionAsync() - { - try - { - await _session.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); - } - finally - { - await _sessionsManager.DeleteSessionAsync(ClientId).ConfigureAwait(false); - } - } - - public Task ClearPendingApplicationMessagesAsync() - { - _session.ClearPendingApplicationMessages(); - - return Task.FromResult(0); - } - } -} diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 35379ce..cf58ec7 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,14 +1,13 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; -using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; +using MQTTnet.Server.Status; namespace MQTTnet.Server { @@ -16,9 +15,9 @@ namespace MQTTnet.Server { private readonly BlockingCollection _messageQueue = new BlockingCollection(); - private readonly AsyncLock _sessionsLock = new AsyncLock(); - private readonly Dictionary _sessions = new Dictionary(); - + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); + private readonly CancellationToken _cancellationToken; private readonly MqttServerEventDispatcher _eventDispatcher; @@ -27,8 +26,8 @@ namespace MQTTnet.Server private readonly IMqttNetChildLogger _logger; public MqttClientSessionsManager( - IMqttServerOptions options, - MqttRetainedMessagesManager retainedMessagesManager, + IMqttServerOptions options, + MqttRetainedMessagesManager retainedMessagesManager, CancellationToken cancellationToken, MqttServerEventDispatcher eventDispatcher, IMqttNetChildLogger logger) @@ -45,98 +44,100 @@ namespace MQTTnet.Server public void Start() { - Task.Factory.StartNew(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); + Task.Run(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken); } public async Task StopAsync() { - List sessions; - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) - { - sessions = _sessions.Values.ToList(); - } + //using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) - foreach (var session in sessions) + foreach (var connection in _connections.Values) { - await session.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); + await connection.StopAsync().ConfigureAwait(false); } } public Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter) { return HandleConnectionAsync(clientAdapter, _cancellationToken); - - // TODO: Check if Task.Run is required. - //return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken); } - public async Task> GetClientStatusAsync() + public Task> GetClientStatusAsync() { - var result = new List(); + var result = new List(); - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + foreach (var connection in _connections.Values) { - foreach (var session in _sessions.Values) - { - var status = new MqttClientSessionStatus(this, session); - session.FillStatus(status); + var clientStatus = new MqttClientStatus(connection, this); + connection.FillStatus(clientStatus); + + var sessionStatus = new MqttSessionStatus(connection.Session, this); + connection.Session.FillStatus(sessionStatus); + clientStatus.Session = sessionStatus; + + result.Add(clientStatus); + } - result.Add(status); - } + return Task.FromResult((IList)result); + } + + public Task> GetSessionStatusAsync() + { + var result = new List(); + + foreach (var session in _sessions.Values) + { + var sessionStatus = new MqttSessionStatus(session, this); + session.FillStatus(sessionStatus); + + result.Add(sessionStatus); } - - return result; + + return Task.FromResult((IList)result); } - public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public void DispatchApplicationMessage(MqttApplicationMessage applicationMessage, MqttClientConnection sender) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); + _messageQueue.Add(new MqttEnqueuedApplicationMessage(applicationMessage, sender), _cancellationToken); } - public async Task SubscribeAsync(string clientId, IEnumerable topicFilters) + public Task SubscribeAsync(string clientId, ICollection topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + if (!_sessions.TryGetValue(clientId, out var session)) { - if (!_sessions.TryGetValue(clientId, out var session)) - { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } - - await session.SubscribeAsync(topicFilters).ConfigureAwait(false); + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } + + return session.SubscribeAsync(topicFilters, _retainedMessagesManager); } - public async Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) + public Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + if (!_sessions.TryGetValue(clientId, out var session)) { - if (!_sessions.TryGetValue(clientId, out var session)) - { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } - - await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false); + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } + + return session.UnsubscribeAsync(topicFilters); } public async Task DeleteSessionAsync(string clientId) { - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + if (_connections.TryGetValue(clientId, out var connection)) { - if (_sessions.TryGetValue(clientId, out var session)) - { - session.Dispose(); - } + await connection.StopAsync(); + } - _sessions.Remove(clientId); + if (_sessions.TryRemove(clientId, out var session)) + { } _logger.Verbose("Session for client '{0}' deleted.", clientId); @@ -179,7 +180,7 @@ namespace MQTTnet.Server { if (interceptorContext.CloseConnection) { - await enqueuedApplicationMessage.Sender.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); + await enqueuedApplicationMessage.Sender.StopAsync().ConfigureAwait(false); } if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) @@ -197,15 +198,12 @@ namespace MQTTnet.Server await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); } - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + foreach (var clientSession in _sessions.Values) { - foreach (var clientSession in _sessions.Values) - { - await clientSession.EnqueueApplicationMessageAsync( - enqueuedApplicationMessage.Sender, - enqueuedApplicationMessage.ApplicationMessage, - false).ConfigureAwait(false); - } + clientSession.EnqueueApplicationMessage( + enqueuedApplicationMessage.ApplicationMessage, + sender?.ClientId, + false); } } catch (OperationCanceledException) @@ -219,6 +217,7 @@ namespace MQTTnet.Server private async Task HandleConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { + var disconnectType = MqttClientDisconnectType.NotClean; var clientId = string.Empty; try @@ -232,22 +231,31 @@ namespace MQTTnet.Server clientId = connectPacket.ClientId; - var connectReturnCode = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); - if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) + var validatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); + + if (validatorContext.ReturnCode != MqttConnectReturnCode.ConnectionAccepted) { + // TODO: Move to channel adapter data converter. + + // Send failure response here without preparing a session. The result for a successful connect + // will be sent from the session itself. await channelAdapter.SendPacketAsync( new MqttConnAckPacket { - ReturnCode = connectReturnCode, + ReturnCode = validatorContext.ReturnCode, ReasonCode = MqttConnectReasonCode.NotAuthorized }, + _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); return; } - var session = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); - await session.RunAsync(connectPacket, channelAdapter).ConfigureAwait(false); + var connection = await CreateConnectionAsync(channelAdapter, connectPacket).ConfigureAwait(false); + + _eventDispatcher.OnClientConnected(clientId); + + disconnectType = await connection.RunAsync().ConfigureAwait(false); } catch (OperationCanceledException) { @@ -258,26 +266,29 @@ namespace MQTTnet.Server } finally { - await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); - - if (!_options.EnablePersistentSessions) + _connections.TryRemove(clientId, out _); + + ////connection?.ReferenceCounter.Decrement(); + ////if (connection?.ReferenceCounter.HasReferences == true) + ////{ + //// disconnectType = MqttClientDisconnectType.Takeover; + ////} + ////else { - // TODO: Check if the session will be used later. - // Consider reference counter or "Recycle" property - // Or add timer (will be required for MQTTv5 (session life time) "IsActiveProperty". - //öö - //await DeleteSessionAsync(clientId).ConfigureAwait(false); + if (!_options.EnablePersistentSessions) + { + await DeleteSessionAsync(clientId).ConfigureAwait(false); + } } + + await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + + _eventDispatcher.OnClientDisconnected(clientId, disconnectType); } } - private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) + private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) { - if (_options.ConnectionValidator == null) - { - return MqttConnectReturnCode.ConnectionAccepted; - } - var context = new MqttConnectionValidatorContext( connectPacket.ClientId, connectPacket.Username, @@ -285,47 +296,69 @@ namespace MQTTnet.Server connectPacket.WillMessage, clientAdapter.Endpoint); - await _options.ConnectionValidator.ValidateConnection(context).ConfigureAwait(false); - return context.ReturnCode; + var connectionValidator = _options.ConnectionValidator; + + if (connectionValidator == null) + { + context.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; + return context; + } + + await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false); + return context; } - private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) + private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1); + + private async Task CreateConnectionAsync(IMqttChannelAdapter channelAdapter, MqttConnectPacket connectPacket) { - using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); + try { - var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); + var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session); + + var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); + if (isConnectionPresent) + { + await existingConnection.StopAsync().ConfigureAwait(false); + } + if (isSessionPresent) { if (connectPacket.CleanSession) { - await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); - - clientSession.Dispose(); - clientSession = null; - - _logger.Verbose("Stopped existing session of client '{0}'.", connectPacket.ClientId); + // TODO: Check if required. + //session.Dispose(); + session = null; + + _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId); } else { - await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); - _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId); } } - if (clientSession == null) + if (session == null) { - clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _eventDispatcher, _logger); - _sessions[connectPacket.ClientId] = clientSession; - + session = new MqttClientSession(connectPacket.ClientId, _eventDispatcher, _options); _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } - return clientSession; + var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, _logger); + + _connections[connection.ClientId] = connection; + _sessions[session.ClientId] = session; + + return connection; + } + finally + { + _createConnectionGate.Release(); } } - private async Task InterceptApplicationMessageAsync(MqttClientSession sender, MqttApplicationMessage applicationMessage) + private async Task InterceptApplicationMessageAsync(MqttClientConnection sender, MqttApplicationMessage applicationMessage) { var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index 3186104..d4c2eea 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -14,7 +14,7 @@ namespace MQTTnet.Server private readonly MqttServerEventDispatcher _eventDispatcher; private readonly string _clientId; - public MqttClientSubscriptionsManager(string clientId, IMqttServerOptions options, MqttServerEventDispatcher eventDispatcher) + public MqttClientSubscriptionsManager(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions options) { _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _options = options ?? throw new ArgumentNullException(nameof(options)); @@ -68,7 +68,29 @@ namespace MQTTnet.Server return result; } - public MqttUnsubAckPacket Unsubscribe(MqttUnsubscribePacket unsubscribePacket) + public async Task SubscribeAsync(IEnumerable topicFilters) + { + foreach (var topicFilter in topicFilters) + { + var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); + if (!interceptorContext.AcceptSubscription) + { + continue; + } + + if (interceptorContext.AcceptSubscription) + { + lock (_subscriptions) + { + _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + } + + _eventDispatcher.OnClientSubscribedTopic(_clientId, topicFilter); + } + } + } + + public Task UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket) { if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); @@ -94,7 +116,22 @@ namespace MQTTnet.Server } } - return unsubAckPacket; + return Task.FromResult(unsubAckPacket); + } + + public Task UnsubscribeAsync(IEnumerable topicFilters) + { + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + lock (_subscriptions) + { + foreach (var topicFilter in topicFilters) + { + _subscriptions.Remove(topicFilter); + } + } + + return Task.FromResult(0); } public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel) diff --git a/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs index 20ff2fe..4711269 100644 --- a/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs +++ b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs @@ -2,13 +2,13 @@ { public class MqttEnqueuedApplicationMessage { - public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) + public MqttEnqueuedApplicationMessage(MqttApplicationMessage applicationMessage, MqttClientConnection sender) { Sender = sender; ApplicationMessage = applicationMessage; } - public MqttClientSession Sender { get; } + public MqttClientConnection Sender { get; } public MqttApplicationMessage ApplicationMessage { get; } } diff --git a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs index 52533d1..7fb7b53 100644 --- a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs @@ -101,29 +101,34 @@ namespace MQTTnet.Server public async Task> GetSubscribedMessagesAsync(ICollection topicFilters) { - var retainedMessages = new List(); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + var matchingRetainedMessages = new List(); + + List retainedMessages; using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { - foreach (var retainedMessage in _messages.Values) + retainedMessages = _messages.Values.ToList(); + } + + foreach (var retainedMessage in retainedMessages) + { + foreach (var topicFilter in topicFilters) { - foreach (var topicFilter in topicFilters) + if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic)) { - if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic)) - { - continue; - } - - retainedMessages.Add(retainedMessage); - break; + continue; } + + matchingRetainedMessages.Add(retainedMessage); + break; } } - - return retainedMessages; + + return matchingRetainedMessages; } - public async Task> GetMessagesAsync() + public async Task> GetMessagesAsync() { using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index d2145bd..4d36e4c 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -7,6 +7,7 @@ using MQTTnet.Adapter; using MQTTnet.Client.Publishing; using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; +using MQTTnet.Server.Status; namespace MQTTnet.Server { @@ -59,17 +60,22 @@ namespace MQTTnet.Server public IMqttServerOptions Options { get; private set; } - public Task> GetClientSessionsStatusAsync() + public Task> GetClientStatusAsync() { return _clientSessionsManager.GetClientStatusAsync(); } - public IList GetRetainedMessages() + public Task> GetSessionStatusAsync() { - return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult(); + return _clientSessionsManager.GetSessionStatusAsync(); } - public Task SubscribeAsync(string clientId, IEnumerable topicFilters) + public Task> GetRetainedMessagesAsync() + { + return _retainedMessagesManager.GetMessagesAsync(); + } + + public Task SubscribeAsync(string clientId, ICollection topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -77,7 +83,7 @@ namespace MQTTnet.Server return _clientSessionsManager.SubscribeAsync(clientId, topicFilters); } - public Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) + public Task UnsubscribeAsync(string clientId, ICollection topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -85,13 +91,13 @@ namespace MQTTnet.Server return _clientSessionsManager.UnsubscribeAsync(clientId, topicFilters); } - public Task PublishAsync(MqttApplicationMessage applicationMessage) + public Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started."); - _clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage); + _clientSessionsManager.DispatchApplicationMessage(applicationMessage, null); return Task.FromResult(new MqttClientPublishResult()); } diff --git a/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs b/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs index 954d10d..8be08c8 100644 --- a/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs +++ b/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs @@ -23,7 +23,7 @@ namespace MQTTnet.Server _callback = callback ?? throw new ArgumentNullException(nameof(callback)); } - public Task ValidateConnection(MqttConnectionValidatorContext context) + public Task ValidateConnectionAsync(MqttConnectionValidatorContext context) { return _callback(context); } diff --git a/Source/MQTTnet/Server/MqttServerEventDispatcher.cs b/Source/MQTTnet/Server/MqttServerEventDispatcher.cs index 8fd5652..03ed93d 100644 --- a/Source/MQTTnet/Server/MqttServerEventDispatcher.cs +++ b/Source/MQTTnet/Server/MqttServerEventDispatcher.cs @@ -24,9 +24,9 @@ namespace MQTTnet.Server ClientUnsubscribedTopic?.Invoke(this, new MqttClientUnsubscribedTopicEventArgs(clientId, topicFilter)); } - public void OnClientDisconnected(string clientId, bool wasCleanDisconnect) + public void OnClientDisconnected(string clientId, MqttClientDisconnectType disconnectType) { - ClientDisconnected?.Invoke(this, new MqttClientDisconnectedEventArgs(clientId, wasCleanDisconnect)); + ClientDisconnected?.Invoke(this, new MqttClientDisconnectedEventArgs(clientId, disconnectType)); } public void OnApplicationMessageReceived(string senderClientId, MqttApplicationMessage applicationMessage) diff --git a/Source/MQTTnet/Server/MqttServerExtensions.cs b/Source/MQTTnet/Server/MqttServerExtensions.cs index e320f13..c705071 100644 --- a/Source/MQTTnet/Server/MqttServerExtensions.cs +++ b/Source/MQTTnet/Server/MqttServerExtensions.cs @@ -1,5 +1,9 @@ using System; +using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; +using MQTTnet.Client; +using MQTTnet.Client.Publishing; using MQTTnet.Protocol; namespace MQTTnet.Server @@ -41,5 +45,89 @@ namespace MQTTnet.Server return server.UnsubscribeAsync(clientId, topicFilters); } + + public static async Task PublishAsync(this IMqttServer server, IEnumerable applicationMessages) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); + + foreach (var applicationMessage in applicationMessages) + { + await server.PublishAsync(applicationMessage).ConfigureAwait(false); + } + } + + public static Task PublishAsync(this IMqttServer server, MqttApplicationMessage applicationMessage) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + return server.PublishAsync(applicationMessage, CancellationToken.None); + } + + public static async Task PublishAsync(this IMqttServer server, params MqttApplicationMessage[] applicationMessages) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); + + foreach (var applicationMessage in applicationMessages) + { + await server.PublishAsync(applicationMessage, CancellationToken.None).ConfigureAwait(false); + } + } + + public static Task PublishAsync(this IMqttServer server, string topic) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return server.PublishAsync(builder => builder + .WithTopic(topic)); + } + + public static Task PublishAsync(this IMqttServer server, string topic, string payload) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return server.PublishAsync(builder => builder + .WithTopic(topic) + .WithPayload(payload)); + } + + public static Task PublishAsync(this IMqttServer server, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return server.PublishAsync(builder => builder + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel(qualityOfServiceLevel)); + } + + public static Task PublishAsync(this IMqttServer server, string topic, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, bool retain) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return server.PublishAsync(builder => builder + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel(qualityOfServiceLevel) + .WithRetainFlag(retain)); + } + + public static Task PublishAsync(this IMqttServer server, Func builder, CancellationToken cancellationToken) + { + var message = builder(new MqttApplicationMessageBuilder()).Build(); + return server.PublishAsync(message, cancellationToken); + } + + public static Task PublishAsync(this IMqttServer server, Func builder) + { + var message = builder(new MqttApplicationMessageBuilder()).Build(); + return server.PublishAsync(message, CancellationToken.None); + } } } diff --git a/Source/MQTTnet/Server/PrepareClientSessionResult.cs b/Source/MQTTnet/Server/PrepareClientSessionResult.cs index 9a655be..7509037 100644 --- a/Source/MQTTnet/Server/PrepareClientSessionResult.cs +++ b/Source/MQTTnet/Server/PrepareClientSessionResult.cs @@ -4,6 +4,6 @@ { public bool IsExistingSession { get; set; } - public MqttClientSession Session { get; set; } + public MqttClientConnection Session { get; set; } } } diff --git a/Source/MQTTnet/Server/Status/IMqttClientStatus.cs b/Source/MQTTnet/Server/Status/IMqttClientStatus.cs new file mode 100644 index 0000000..b751f29 --- /dev/null +++ b/Source/MQTTnet/Server/Status/IMqttClientStatus.cs @@ -0,0 +1,31 @@ +using System; +using System.Threading.Tasks; +using MQTTnet.Formatter; + +namespace MQTTnet.Server.Status +{ + public interface IMqttClientStatus + { + string ClientId { get; } + + string Endpoint { get; } + + MqttProtocolVersion ProtocolVersion { get; } + + DateTime LastPacketReceivedTimestamp { get; } + + DateTime LastNonKeepAlivePacketReceivedTimestamp { get; } + + long ReceivedApplicationMessagesCount { get; } + + long SentApplicationMessagesCount { get; } + + long ReceivedPacketsCount { get; } + + long SentPacketsCount { get; } + + IMqttSessionStatus Session { get; } + + Task DisconnectAsync(); + } +} diff --git a/Source/MQTTnet/Server/Status/IMqttSessionStatus.cs b/Source/MQTTnet/Server/Status/IMqttSessionStatus.cs new file mode 100644 index 0000000..ada8efe --- /dev/null +++ b/Source/MQTTnet/Server/Status/IMqttSessionStatus.cs @@ -0,0 +1,17 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server.Status +{ + public interface IMqttSessionStatus + { + string ClientId { get; set; } + + bool IsConnected { get; } + + long PendingApplicationMessagesCount { get; set; } + + Task ClearPendingApplicationMessagesAsync(); + + Task DeleteAsync(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/Status/MqttClientStatus.cs b/Source/MQTTnet/Server/Status/MqttClientStatus.cs new file mode 100644 index 0000000..79daa46 --- /dev/null +++ b/Source/MQTTnet/Server/Status/MqttClientStatus.cs @@ -0,0 +1,43 @@ +using System; +using System.Threading.Tasks; +using MQTTnet.Formatter; + +namespace MQTTnet.Server.Status +{ + public class MqttClientStatus : IMqttClientStatus + { + private readonly MqttClientSessionsManager _sessionsManager; + private readonly MqttClientConnection _connection; + + public MqttClientStatus(MqttClientConnection connection, MqttClientSessionsManager sessionsManager) + { + _connection = connection; + _sessionsManager = sessionsManager; + } + + public string ClientId { get; set; } + + public string Endpoint { get; set; } + + public MqttProtocolVersion ProtocolVersion { get; set; } + + public DateTime LastPacketReceivedTimestamp { get; set; } + + public DateTime LastNonKeepAlivePacketReceivedTimestamp { get; set; } + + public long ReceivedApplicationMessagesCount { get; set; } + + public long SentApplicationMessagesCount { get; set; } + + public long ReceivedPacketsCount { get; set; } + + public long SentPacketsCount { get; set; } + + public IMqttSessionStatus Session { get; set; } + + public Task DisconnectAsync() + { + return _connection.StopAsync(); + } + } +} diff --git a/Source/MQTTnet/Server/Status/MqttSessionStatus.cs b/Source/MQTTnet/Server/Status/MqttSessionStatus.cs new file mode 100644 index 0000000..401fc37 --- /dev/null +++ b/Source/MQTTnet/Server/Status/MqttSessionStatus.cs @@ -0,0 +1,36 @@ +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Server.Status +{ + public class MqttSessionStatus : IMqttSessionStatus + { + private readonly MqttClientSession _session; + private readonly MqttClientSessionsManager _sessionsManager; + + public MqttSessionStatus(MqttClientSession session, MqttClientSessionsManager sessionsManager) + { + _session = session ?? throw new ArgumentNullException(nameof(session)); + _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); + } + + public string ClientId { get; set; } + + public long PendingApplicationMessagesCount { get; set; } + + public bool IsConnected { get; set; } + + public DateTime CreatedTimestamp { get; set; } + + public Task DeleteAsync() + { + return _sessionsManager.DeleteSessionAsync(ClientId); + } + + public Task ClearPendingApplicationMessagesAsync() + { + _session.ApplicationMessagesQueue.Clear(); + return Task.FromResult(0); + } + } +} diff --git a/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs b/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs index 9edcdc2..8fb74ae 100644 --- a/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs +++ b/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs @@ -25,7 +25,7 @@ namespace MQTTnet.AspNetCore.Tests pipe.Receive.Writer.Complete(); - await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(TimeSpan.FromSeconds(1), CancellationToken.None)); + await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None)); } [TestMethod] @@ -41,7 +41,7 @@ namespace MQTTnet.AspNetCore.Tests { for (int i = 0; i < 100; i++) { - await ctx.SendPacketAsync(new MqttPublishPacket(), CancellationToken.None).ConfigureAwait(false); + await ctx.SendPacketAsync(new MqttPublishPacket(), TimeSpan.Zero, CancellationToken.None).ConfigureAwait(false); } })); diff --git a/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs b/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs index ea03570..3a5c696 100644 --- a/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs @@ -66,7 +66,7 @@ namespace MQTTnet.Benchmarks for (var i = 0; i < 10000; i++) { - _channelAdapter.SendPacketAsync(_packet, CancellationToken.None).GetAwaiter().GetResult(); + _channelAdapter.SendPacketAsync(_packet, TimeSpan.Zero, CancellationToken.None).GetAwaiter().GetResult(); } _stream.Position = 0; diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index 346626e..d713dfc 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -85,7 +85,7 @@ namespace MQTTnet.Benchmarks throw new NotImplementedException(); } - public Task DisconnectAsync() + public Task DisconnectAsync(CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index a2f7562..d1b5c1c 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -4,6 +4,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Exceptions; +using MQTTnet.Internal; namespace MQTTnet.Tests { @@ -14,20 +15,20 @@ namespace MQTTnet.Tests [TestMethod] public async Task TimeoutAfter() { - await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MqttTaskTimeout.WaitAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] public async Task TimeoutAfterWithResult() { - await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MqttTaskTimeout.WaitAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [TestMethod] public async Task TimeoutAfterCompleteInTime() { - var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); + var result = await MqttTaskTimeout.WaitAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); Assert.AreEqual(5, result); } @@ -36,7 +37,7 @@ namespace MQTTnet.Tests { try { - await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => + await MqttTaskTimeout.WaitAsync(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -55,7 +56,7 @@ namespace MQTTnet.Tests { try { - await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => + await MqttTaskTimeout.WaitAsync(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -76,7 +77,7 @@ namespace MQTTnet.Tests var tasks = Enumerable.Range(0, 100000) .Select(i => { - return MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); + return MqttTaskTimeout.WaitAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); }); await Task.WhenAll(tasks); diff --git a/Tests/MQTTnet.Core.Tests/MQTTv5/MqttClientTests.cs b/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs similarity index 95% rename from Tests/MQTTnet.Core.Tests/MQTTv5/MqttClientTests.cs rename to Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs index 3d7c97c..6c13dbb 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTv5/MqttClientTests.cs +++ b/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs @@ -16,7 +16,7 @@ namespace MQTTnet.Tests.MQTTv5 public class Client_Tests { [TestMethod] - public async Task Client_Connect() + public async Task Connect() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -33,7 +33,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Connect_And_Disconnect() + public async Task Connect_And_Disconnect() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -52,7 +52,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Subscribe() + public async Task Subscribe() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -75,7 +75,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Unsubscribe() + public async Task Unsubscribe() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -99,7 +99,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Publish_QoS0() + public async Task Publish_QoS_0() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -121,7 +121,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Publish_QoS1() + public async Task Publish_QoS_1() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -143,7 +143,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Publish_QoS2() + public async Task Publish_QoS_2() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -165,7 +165,7 @@ namespace MQTTnet.Tests.MQTTv5 } [TestMethod] - public async Task Client_Publish_With_Properties() + public async Task Publish_With_Properties() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); diff --git a/Tests/MQTTnet.Core.Tests/TestSetup.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs similarity index 88% rename from Tests/MQTTnet.Core.Tests/TestSetup.cs rename to Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs index b1515ca..d15d928 100644 --- a/Tests/MQTTnet.Core.Tests/TestSetup.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs @@ -7,9 +7,9 @@ using MQTTnet.Client.Options; using MQTTnet.Diagnostics; using MQTTnet.Server; -namespace MQTTnet.Tests +namespace MQTTnet.Tests.Mockups { - public class TestSetup : IDisposable + public class TestEnvironment : IDisposable { private readonly MqttFactory _mqttFactory = new MqttFactory(); private readonly List _clients = new List(); @@ -24,9 +24,12 @@ namespace MQTTnet.Tests private IMqttServer _server; public bool IgnoreClientLogErrors { get; set; } + public bool IgnoreServerLogErrors { get; set; } - public TestSetup() + public int ServerPort { get; set; } = 1888; + + public TestEnvironment() { _serverLogger.LogMessagePublished += (s, e) => { @@ -51,6 +54,11 @@ namespace MQTTnet.Tests }; } + public IMqttClient CreateClient() + { + return _mqttFactory.CreateMqttClient(_clientLogger); + } + public Task StartServerAsync() { return StartServerAsync(new MqttServerOptionsBuilder()); @@ -64,7 +72,7 @@ namespace MQTTnet.Tests } _server = _mqttFactory.CreateMqttServer(_serverLogger); - await _server.StartAsync(options.WithDefaultEndpointPort(1888).Build()); + await _server.StartAsync(options.WithDefaultEndpointPort(ServerPort).Build()); return _server; } @@ -77,7 +85,7 @@ namespace MQTTnet.Tests public async Task ConnectClientAsync(MqttClientOptionsBuilder options) { var client = _mqttFactory.CreateMqttClient(_clientLogger); - await client.ConnectAsync(options.WithTcpServer("localhost", 1888).Build()); + await client.ConnectAsync(options.WithTcpServer("localhost", ServerPort).Build()); _clients.Add(client); return client; @@ -106,17 +114,16 @@ namespace MQTTnet.Tests { foreach (var mqttClient in _clients) { - mqttClient?.DisconnectAsync().GetAwaiter().GetResult(); mqttClient?.Dispose(); } - _server.StopAsync().GetAwaiter().GetResult(); + _server?.StopAsync().GetAwaiter().GetResult(); ThrowIfLogErrors(); if (_exceptions.Any()) { - throw new Exception($"{_exceptions.Count} exceptions tracked."); + throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); } } diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs similarity index 67% rename from Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs rename to Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs index 4a7475d..73676dd 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs @@ -6,7 +6,7 @@ using MQTTnet.Adapter; using MQTTnet.Formatter; using MQTTnet.Packets; -namespace MQTTnet.Tests +namespace MQTTnet.Tests.Mockups { public class TestMqttCommunicationAdapter : IMqttChannelAdapter { @@ -35,7 +35,7 @@ namespace MQTTnet.Tests return Task.FromResult(0); } - public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + public Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); @@ -44,30 +44,11 @@ namespace MQTTnet.Tests return Task.FromResult(0); } - public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) + public Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); - - if (timeout > TimeSpan.Zero) - { - using (var timeoutCts = new CancellationTokenSource(timeout)) - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) - { - return await Task.Run(() => - { - try - { - return _incomingPackets.Take(cts.Token); - } - catch - { - return null; - } - }, cts.Token); - } - } - - return await Task.Run(() => + + return Task.Run(() => { try { diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapterFactory.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapterFactory.cs similarity index 92% rename from Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapterFactory.cs rename to Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapterFactory.cs index 0263ff6..ff47f71 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapterFactory.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapterFactory.cs @@ -1,9 +1,8 @@ using MQTTnet.Adapter; -using MQTTnet.Client; using MQTTnet.Client.Options; using MQTTnet.Diagnostics; -namespace MQTTnet.Tests +namespace MQTTnet.Tests.Mockups { public class TestMqttCommunicationAdapterFactory : IMqttClientAdapterFactory { diff --git a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttServerAdapter.cs similarity index 98% rename from Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs rename to Tests/MQTTnet.Core.Tests/Mockups/TestMqttServerAdapter.cs index cfcba04..014a906 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttServerAdapter.cs @@ -7,7 +7,7 @@ using MQTTnet.Client.Options; using MQTTnet.Diagnostics; using MQTTnet.Server; -namespace MQTTnet.Tests +namespace MQTTnet.Tests.Mockups { public class TestMqttServerAdapter : IMqttServerAdapter { diff --git a/Tests/MQTTnet.Core.Tests/TestServerExtensions.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestServerExtensions.cs similarity index 97% rename from Tests/MQTTnet.Core.Tests/TestServerExtensions.cs rename to Tests/MQTTnet.Core.Tests/Mockups/TestServerExtensions.cs index 5bfa2c8..1cbd1b7 100644 --- a/Tests/MQTTnet.Core.Tests/TestServerExtensions.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestServerExtensions.cs @@ -2,7 +2,7 @@ using MQTTnet.Client; using MQTTnet.Server; -namespace MQTTnet.Tests +namespace MQTTnet.Tests.Mockups { public static class TestServerExtensions { diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestServerStorage.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestServerStorage.cs new file mode 100644 index 0000000..55ca6bb --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestServerStorage.cs @@ -0,0 +1,22 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using MQTTnet.Server; + +namespace MQTTnet.Tests.Mockups +{ + public class TestServerStorage : IMqttServerStorage + { + public IList Messages = new List(); + + public Task SaveRetainedMessagesAsync(IList messages) + { + Messages = messages; + return Task.CompletedTask; + } + + public Task> LoadRetainedMessagesAsync() + { + return Task.FromResult(Messages); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs index 9725fee..9604a61 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; using System.Linq; using System.Net.Sockets; -using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; @@ -11,50 +10,73 @@ using MQTTnet.Client.Receiving; using MQTTnet.Exceptions; using MQTTnet.Protocol; using MQTTnet.Server; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests { [TestClass] - public class MqttClientTests + public class Client_Tests { [TestMethod] - public async Task Client_Disconnect_Exception() + public async Task Invalid_Connect_Throws_Exception() { var factory = new MqttFactory(); - var client = factory.CreateMqttClient(); - - Exception ex = null; - client.Disconnected += (s, e) => + using (var client = factory.CreateMqttClient()) { - ex = e.Exception; - }; + try + { + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("wrong-server").Build()); - try - { - await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("wrong-server").Build()); + Assert.Fail("Must fail!"); + } + catch (Exception exception) + { + Assert.IsNotNull(exception); + Assert.IsInstanceOfType(exception, typeof(MqttCommunicationException)); + Assert.IsInstanceOfType(exception.InnerException, typeof(SocketException)); + } } - catch + } + + [TestMethod] + public async Task Disconnect_Event_Contains_Exception() + { + var factory = new MqttFactory(); + using (var client = factory.CreateMqttClient()) { - } + Exception ex = null; + client.Disconnected += (s, e) => + { + ex = e.Exception; + }; + + try + { + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("wrong-server").Build()); + } + catch + { + } - Assert.IsNotNull(ex); - Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException)); - Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException)); + Assert.IsNotNull(ex); + Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException)); + Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException)); + } } [TestMethod] - public async Task Client_Preserve_Message_Order() + public async Task Preserve_Message_Order() { // The messages are sent in reverse or to ensure that the delay in the handler // needs longer for the first messages and later messages may be processed earlier (if there // is an issue). const int MessagesCount = 50; - using (var testSetup = new TestSetup()) + using (var testEnvironment = new TestEnvironment()) { - await testSetup.StartServerAsync(); + await testEnvironment.StartServerAsync(); - var client1 = await testSetup.ConnectClientAsync(); + var client1 = await testEnvironment.ConnectClientAsync(); await client1.SubscribeAsync("x"); var receivedValues = new List(); @@ -72,7 +94,7 @@ namespace MQTTnet.Tests client1.UseReceivedApplicationMessageHandler(Handler1); - var client2 = await testSetup.ConnectClientAsync(); + var client2 = await testEnvironment.ConnectClientAsync(); for (var i = MessagesCount; i > 0; i--) { await client2.PublishAsync("x", i.ToString()); @@ -88,13 +110,13 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task Client_Send_Reply_For_Any_Received_Message() + public async Task Send_Reply_For_Any_Received_Message() { - using (var testSetup = new TestSetup()) + using (var testEnvironment = new TestEnvironment()) { - await testSetup.StartServerAsync(); + await testEnvironment.StartServerAsync(); - var client1 = await testSetup.ConnectClientAsync(); + var client1 = await testEnvironment.ConnectClientAsync(); await client1.SubscribeAsync("request/+"); async Task Handler1(MqttApplicationMessageHandlerContext context) @@ -104,7 +126,7 @@ namespace MQTTnet.Tests client1.UseReceivedApplicationMessageHandler(Handler1); - var client2 = await testSetup.ConnectClientAsync(); + var client2 = await testEnvironment.ConnectClientAsync(); await client2.SubscribeAsync("reply/#"); var replies = new List(); @@ -132,30 +154,26 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task Client_Publish() + public async Task Publish_With_Correct_Retain_Flag() { - var server = new MqttFactory().CreateMqttServer(); - - try + using (var testEnvironment = new TestEnvironment()) { - var receivedMessages = new List(); + await testEnvironment.StartServerAsync(); - await server.StartAsync(new MqttServerOptions()); + var receivedMessages = new List(); - var client1 = new MqttFactory().CreateMqttClient(); - client1.ApplicationMessageReceived += (_, e) => + var client1 = await testEnvironment.ConnectClientAsync(); + client1.UseReceivedApplicationMessageHandler(c => { lock (receivedMessages) { - receivedMessages.Add(e.ApplicationMessage); + receivedMessages.Add(c.ApplicationMessage); } - }; + }); - await client1.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1").Build()); await client1.SubscribeAsync("a"); - var client2 = new MqttFactory().CreateMqttClient(); - await client2.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1").Build()); + var client2 = await testEnvironment.ConnectClientAsync(); var message = new MqttApplicationMessageBuilder().WithTopic("a").WithRetainFlag().Build(); await client2.PublishAsync(message); @@ -164,26 +182,18 @@ namespace MQTTnet.Tests Assert.AreEqual(1, receivedMessages.Count); Assert.IsFalse(receivedMessages.First().Retain); // Must be false even if set above! } - finally - { - await server.StopAsync(); - } } [TestMethod] - public async Task Publish_Special_Content() + public async Task Subscribe_In_Callback_Events() { - var factory = new MqttFactory(); - var server = factory.CreateMqttServer(); - var serverOptions = new MqttServerOptionsBuilder().Build(); - - var receivedMessages = new List(); + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); - var client = factory.CreateMqttClient(); + var receivedMessages = new List(); - try - { - await server.StartAsync(serverOptions); + var client = testEnvironment.CreateClient(); client.Connected += async (s, e) => { @@ -196,43 +206,37 @@ namespace MQTTnet.Tests await client.PublishAsync(msg.Build()); }; - client.ApplicationMessageReceived += (s, e) => + client.UseReceivedApplicationMessageHandler(c => { lock (receivedMessages) { - receivedMessages.Add(e.ApplicationMessage); + receivedMessages.Add(c.ApplicationMessage); } - }; + }); - await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()); + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost", testEnvironment.ServerPort).Build()); await Task.Delay(500); Assert.AreEqual(1, receivedMessages.Count); Assert.AreEqual("DA|18RS00SC00XI0000RV00R100R200R300R400L100L200L300L400Y100Y200AC0102031800BELK0000BM0000|", receivedMessages.First().ConvertPayloadToString()); } - finally - { - await server.StopAsync(); - } } [TestMethod] - public async Task Client_Exception_In_Application_Message_Handler() + public async Task Message_Send_Retry() { - using (var testSetup = new TestSetup()) + using (var testEnvironment = new TestEnvironment()) { - testSetup.IgnoreClientLogErrors = true; - testSetup.IgnoreServerLogErrors = true; + testEnvironment.IgnoreClientLogErrors = true; + testEnvironment.IgnoreServerLogErrors = true; - await testSetup.StartServerAsync( + await testEnvironment.StartServerAsync( new MqttServerOptionsBuilder() .WithPersistentSessions() - .WithDefaultCommunicationTimeout(TimeSpan.FromMilliseconds(50))); - - var client1 = await testSetup.ConnectClientAsync(new MqttClientOptionsBuilder() - .WithCleanSession(false)); + .WithDefaultCommunicationTimeout(TimeSpan.FromMilliseconds(250))); + var client1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithCleanSession(false)); await client1.SubscribeAsync("x", MqttQualityOfServiceLevel.AtLeastOnce); var retries = 0; @@ -241,16 +245,16 @@ namespace MQTTnet.Tests { retries++; - await Task.Delay(50); + await Task.Delay(1000); throw new Exception("Broken!"); } client1.UseReceivedApplicationMessageHandler(Handler1); - var client2 = await testSetup.ConnectClientAsync(); + var client2 = await testEnvironment.ConnectClientAsync(); await client2.PublishAsync("x"); - await Task.Delay(1000); + await Task.Delay(3000); // The server should disconnect clients which are not responding. Assert.IsFalse(client1.IsConnected); @@ -262,40 +266,5 @@ namespace MQTTnet.Tests Assert.AreEqual(2, retries); } } - - //#if DEBUG - // [TestMethod] - // public async Task Client_Cleanup_On_Authentification_Fails() - // { - // var channel = new TestMqttCommunicationAdapter(); - // var channel2 = new TestMqttCommunicationAdapter(); - // channel.Partner = channel2; - // channel2.Partner = channel; - - // Task.Run(async () => { - // var connect = await channel2.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None); - // await channel2.SendPacketAsync(new MqttConnAckPacket - // { - // ConnectReturnCode = Protocol.MqttConnectReturnCode.ConnectionRefusedNotAuthorized - // }, CancellationToken.None); - // }); - - // var fake = new TestMqttCommunicationAdapterFactory(channel); - - // var client = new MqttClient(fake, new MqttNetLogger()); - - // try - // { - // await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("any-server").Build()); - // } - // catch (Exception ex) - // { - // Assert.IsInstanceOfType(ex, typeof(MqttConnectingFailedException)); - // } - - // Assert.IsTrue(client._packetReceiverTask == null || client._packetReceiverTask.IsCompleted, "receive loop not completed"); - // Assert.IsTrue(client._keepAliveMessageSenderTask == null || client._keepAliveMessageSenderTask.IsCompleted, "keepalive loop not completed"); - // } - //#endif } } diff --git a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs index 75b088a..493f5ab 100644 --- a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs @@ -7,6 +7,7 @@ using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Packets; using MQTTnet.Server; +using MQTTnet.Server.Status; namespace MQTTnet.Tests { @@ -44,9 +45,9 @@ namespace MQTTnet.Tests // Simulate traffic. Thread.Sleep(1000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. - monitor.PacketReceived(new MqttPublishPacket()); + monitor.PacketReceived(); Thread.Sleep(1000); - monitor.PacketReceived(new MqttPublishPacket()); + monitor.PacketReceived(); Thread.Sleep(1000); Assert.AreEqual(0, clientSession.StopCalledCount); @@ -62,12 +63,12 @@ namespace MQTTnet.Tests public int StopCalledCount { get; private set; } - public void FillStatus(MqttClientSessionStatus status) + public void FillStatus(MqttClientStatus status) { throw new NotSupportedException(); } - public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public void EnqueueApplicationMessage(MqttClientConnection senderClientSession, MqttApplicationMessage applicationMessage) { throw new NotSupportedException(); } @@ -82,7 +83,7 @@ namespace MQTTnet.Tests throw new NotSupportedException(); } - public Task StopAsync(MqttClientDisconnectType disconnectType) + public Task StopAsync() { StopCalledCount++; return Task.FromResult(0); diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs deleted file mode 100644 index fc66c2d..0000000 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ /dev/null @@ -1,1031 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Net.Sockets; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Adapter; -using MQTTnet.Client; -using MQTTnet.Client.Options; -using MQTTnet.Diagnostics; -using MQTTnet.Protocol; -using MQTTnet.Server; - -namespace MQTTnet.Tests -{ - [TestClass] - public class MqttServerTests - { - [TestMethod] - public async Task MqttServer_PublishSimple_AtMostOnce() - { - await TestPublishAsync( - "A/B/C", - MqttQualityOfServiceLevel.AtMostOnce, - "A/B/C", - MqttQualityOfServiceLevel.AtMostOnce, - 1); - } - - [TestMethod] - public async Task MqttServer_PublishSimple_AtLeastOnce() - { - await TestPublishAsync( - "A/B/C", - MqttQualityOfServiceLevel.AtLeastOnce, - "A/B/C", - MqttQualityOfServiceLevel.AtLeastOnce, - 1); - } - - [TestMethod] - public async Task MqttServer_PublishSimple_ExactlyOnce() - { - await TestPublishAsync( - "A/B/C", - MqttQualityOfServiceLevel.ExactlyOnce, - "A/B/C", - MqttQualityOfServiceLevel.ExactlyOnce, - 1); - } - - [TestMethod] - public async Task MqttServer_Will_Message() - { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var receivedMessagesCount = 0; - try - { - await s.StartAsync(new MqttServerOptions()); - - var willMessage = new MqttApplicationMessageBuilder().WithTopic("My/last/will").WithAtMostOnceQoS().Build(); - var c1 = await serverAdapter.ConnectTestClient("c1"); - var c2 = await serverAdapter.ConnectTestClient("c2", willMessage); - - c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic("#").Build()); - - await c2.DisconnectAsync(); - - await Task.Delay(1000); - - await c1.DisconnectAsync(); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(0, receivedMessagesCount); - } - - [TestMethod] - public async Task MqttServer_Subscribe_Unsubscribe() - { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var receivedMessagesCount = 0; - - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - var c2 = await serverAdapter.ConnectTestClient("c2"); - c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - - var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); - - await c2.PublishAsync(message); - await Task.Delay(1000); - Assert.AreEqual(0, receivedMessagesCount); - - var subscribeEventCalled = false; - s.ClientSubscribedTopic += (_, e) => - { - subscribeEventCalled = e.TopicFilter.Topic == "a" && e.ClientId == "c1"; - }; - - await c1.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - await Task.Delay(500); - Assert.IsTrue(subscribeEventCalled, "Subscribe event not called."); - - await c2.PublishAsync(message); - await Task.Delay(500); - Assert.AreEqual(1, receivedMessagesCount); - - var unsubscribeEventCalled = false; - s.ClientUnsubscribedTopic += (_, e) => - { - unsubscribeEventCalled = e.TopicFilter == "a" && e.ClientId == "c1"; - }; - - await c1.UnsubscribeAsync("a"); - await Task.Delay(500); - Assert.IsTrue(unsubscribeEventCalled, "Unsubscribe event not called."); - - await c2.PublishAsync(message); - await Task.Delay(1000); - Assert.AreEqual(1, receivedMessagesCount); - } - finally - { - await s.StopAsync(); - } - await Task.Delay(500); - - Assert.AreEqual(1, receivedMessagesCount); - } - - [TestMethod] - public async Task MqttServer_Publish_From_Server() - { - using (var testSetup = new TestSetup()) - { - var server = await testSetup.StartServerAsync(); - - var receivedMessagesCount = 0; - - var client = await testSetup.ConnectClientAsync(); - client.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); - - var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); - await client.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - - await server.PublishAsync(message); - - await Task.Delay(1000); - await server.StopAsync(); - - Assert.AreEqual(1, receivedMessagesCount); - } - } - - [TestMethod] - public async Task MqttServer_Publish_Multiple_Clients() - { - var s = new MqttFactory().CreateMqttServer(); - var receivedMessagesCount = 0; - var locked = new object(); - - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost") - .Build(); - - var clientOptions2 = new MqttClientOptionsBuilder() - .WithTcpServer("localhost") - .Build(); - - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = new MqttFactory().CreateMqttClient(); - var c2 = new MqttFactory().CreateMqttClient(); - - await c1.ConnectAsync(clientOptions); - await c2.ConnectAsync(clientOptions2); - - c1.ApplicationMessageReceived += (_, __) => - { - lock (locked) - { - receivedMessagesCount++; - } - }; - - c2.ApplicationMessageReceived += (_, __) => - { - lock (locked) - { - receivedMessagesCount++; - } - }; - - var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); - await c1.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - await c2.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - - //await Task.WhenAll(Publish(c1, message), Publish(c2, message)); - await Publish(c1, message); - - await Task.Delay(500); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(2000, receivedMessagesCount); - } - - [TestMethod] - public async Task MqttServer_Session_Takeover() - { - var server = new MqttFactory().CreateMqttServer(); - try - { - await server.StartAsync(new MqttServerOptions()); - - var client1 = new MqttFactory().CreateMqttClient(); - var client2 = new MqttFactory().CreateMqttClient(); - - var options = new MqttClientOptionsBuilder() - .WithTcpServer("localhost") - .WithCleanSession(false) - .WithClientId("a").Build(); - - await client1.ConnectAsync(options); - - await Task.Delay(500); - - await client2.ConnectAsync(options); - - await Task.Delay(500); - - Assert.IsFalse(client1.IsConnected); - Assert.IsTrue(client2.IsConnected); - } - finally - { - await server.StopAsync(); - } - } - - [TestMethod] - public async Task MqttServer_No_Messages_If_No_Subscription() - { - var server = new MqttFactory().CreateMqttServer(); - try - { - await server.StartAsync(new MqttServerOptions()); - - var client = new MqttFactory().CreateMqttClient(); - var receivedMessages = new List(); - - var options = new MqttClientOptionsBuilder() - .WithTcpServer("localhost").Build(); - - client.Connected += async (s, e) => - { - await client.PublishAsync("Connected"); - }; - - client.ApplicationMessageReceived += (s, e) => - { - lock (receivedMessages) - { - receivedMessages.Add(e.ApplicationMessage); - } - }; - - await client.ConnectAsync(options); - - await Task.Delay(500); - - await client.PublishAsync("Hello"); - - await Task.Delay(500); - - Assert.AreEqual(0, receivedMessages.Count); - } - finally - { - await server.StopAsync(); - } - } - - [TestMethod] - public async Task MqttServer_Set_Subscription_At_Server() - { - var server = new MqttFactory().CreateMqttServer(); - try - { - await server.StartAsync(new MqttServerOptions()); - server.ClientConnected += async (s, e) => - { - await server.SubscribeAsync(e.ClientId, "topic1"); - }; - - var client = new MqttFactory().CreateMqttClient(); - var receivedMessages = new List(); - - var options = new MqttClientOptionsBuilder() - .WithTcpServer("localhost").Build(); - - client.ApplicationMessageReceived += (s, e) => - { - lock (receivedMessages) - { - receivedMessages.Add(e.ApplicationMessage); - } - }; - - await client.ConnectAsync(options); - - await Task.Delay(500); - - await client.PublishAsync("Hello"); - - await Task.Delay(500); - - Assert.AreEqual(0, receivedMessages.Count); - } - finally - { - await server.StopAsync(); - } - } - - private static async Task Publish(IMqttClient c1, MqttApplicationMessage message) - { - for (int i = 0; i < 1000; i++) - { - await c1.PublishAsync(message); - } - } - - [TestMethod] - public async Task MqttServer_Shutdown_Disconnects_Clients_Gracefully() - { - using (var testSetup = new TestSetup()) - { - var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - - var disconnectCalled = 0; - - var c1 = await testSetup.ConnectClientAsync(new MqttClientOptionsBuilder()); - c1.Disconnected += (sender, args) => disconnectCalled++; - - await Task.Delay(100); - - await server.StopAsync(); - - await Task.Delay(100); - - Assert.AreEqual(1, disconnectCalled); - } - } - - [TestMethod] - public async Task MqttServer_Handle_Clean_Disconnect() - { - using (var testSetup = new TestSetup()) - { - var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - - var clientConnectedCalled = 0; - var clientDisconnectedCalled = 0; - - server.ClientConnected += (_, __) => Interlocked.Increment(ref clientConnectedCalled); - server.ClientDisconnected += (_, __) => Interlocked.Increment(ref clientDisconnectedCalled); - - var c1 = await testSetup.ConnectClientAsync(new MqttClientOptionsBuilder()); - - Assert.AreEqual(1, clientConnectedCalled); - Assert.AreEqual(0, clientDisconnectedCalled); - - await Task.Delay(500); - - await c1.DisconnectAsync(); - - await Task.Delay(500); - - Assert.AreEqual(1, clientConnectedCalled); - Assert.AreEqual(1, clientDisconnectedCalled); - } - } - - [TestMethod] - public async Task MqttServer_Client_Disconnect_Without_Errors() - { - using (var testSetup = new TestSetup()) - { - bool clientWasConnected; - - var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - try - { - var client = await testSetup.ConnectClientAsync(new MqttClientOptionsBuilder()); - - clientWasConnected = true; - - await client.DisconnectAsync(); - - await Task.Delay(500); - } - finally - { - await server.StopAsync(); - } - - Assert.IsTrue(clientWasConnected); - - testSetup.ThrowIfLogErrors(); - } - } - - [TestMethod] - public async Task MqttServer_Lots_Of_Retained_Messages() - { - const int ClientCount = 25; - - using (var testSetup = new TestSetup()) - { - var server = await testSetup.StartServerAsync(); - - var tasks = new ConcurrentBag(); - for (var i = 0; i < ClientCount; i++) - { - var clientId = i; - tasks.Add(Task.Run(async () => - { - try - { - using (var client = await testSetup.ConnectClientAsync()) - { - // Clear retained message. - await client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + clientId) - .WithPayload(new byte[0]).WithRetainFlag().Build()); - - // Set retained message. - await client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + clientId) - .WithPayload("value").WithRetainFlag().Build()); - - await Task.Delay(10); - - await client.DisconnectAsync(); - } - } - catch (Exception exception) - { - testSetup.TrackException(exception); - } - })); - } - - await Task.WhenAll(tasks.ToArray()); - - await Task.Delay(1000); - - var retainedMessages = server.GetRetainedMessages(); - - Assert.AreEqual(ClientCount, retainedMessages.Count); - - for (var i = 0; i < ClientCount; i++) - { - Assert.IsTrue(retainedMessages.Any(m => m.Topic == "r" + i)); - } - } - } - - [TestMethod] - public async Task MqttServer_Retained_Messages_Flow() - { - var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient("c1"); - await c1.PublishAsync(retainedMessage); - await Task.Delay(500); - await c1.DisconnectAsync(); - await Task.Delay(500); - - var receivedMessages = 0; - var c2 = await serverAdapter.ConnectTestClient("c2"); - c2.ApplicationMessageReceived += (_, e) => - { - receivedMessages++; - }; - - for (var i = 0; i < 5; i++) - { - await c2.UnsubscribeAsync("r"); - await Task.Delay(500); - Assert.AreEqual(i, receivedMessages); - - await c2.SubscribeAsync("r"); - await Task.Delay(500); - Assert.AreEqual(i + 1, receivedMessages); - } - - await c2.DisconnectAsync(); - - await s.StopAsync(); - } - - [TestMethod] - public async Task MqttServer_No_Retained_Message() - { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var receivedMessagesCount = 0; - - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - await c1.PublishAsync(builder => builder.WithTopic("retained").WithPayload(new byte[3])); - await c1.DisconnectAsync(); - - var c2 = await serverAdapter.ConnectTestClient("c2"); - c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); - - await Task.Delay(500); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(0, receivedMessagesCount); - } - - [TestMethod] - public async Task MqttServer_Retained_Message() - { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var receivedMessages = new List(); - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); - await c1.DisconnectAsync(); - - var c2 = await serverAdapter.ConnectTestClient("c2"); - c2.ApplicationMessageReceived += (_, e) => - { - lock (receivedMessages) - { - receivedMessages.Add(e.ApplicationMessage); - } - }; - - await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); - - await Task.Delay(500); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(1, receivedMessages.Count); - Assert.IsTrue(receivedMessages.First().Retain); - } - - [TestMethod] - public async Task MqttServer_Clear_Retained_Message() - { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var receivedMessagesCount = 0; - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - await c1.PublishAsync(builder => builder.WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag()); - await c1.PublishAsync(builder => builder.WithTopic("retained").WithPayload(new byte[0]).WithRetainFlag()); - await c1.DisconnectAsync(); - - var c2 = await serverAdapter.ConnectTestClient("c2"); - c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - - await Task.Delay(200); - await c2.SubscribeAsync(new TopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); - await Task.Delay(500); - } - finally - { - await s.StopAsync(); - } - - - Assert.AreEqual(0, receivedMessagesCount); - } - - [TestMethod] - public async Task MqttServer_Persist_Retained_Message() - { - var storage = new TestStorage(); - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - try - { - var options = new MqttServerOptions { Storage = storage }; - - await s.StartAsync(options); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - - await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); - - await Task.Delay(250); - - await c1.DisconnectAsync(); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(1, storage.Messages.Count); - - s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var receivedMessagesCount = 0; - try - { - var options = new MqttServerOptions { Storage = storage }; - await s.StartAsync(options); - - var c2 = await serverAdapter.ConnectTestClient("c2"); - c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); - - await Task.Delay(250); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(1, receivedMessagesCount); - } - - [TestMethod] - public async Task MqttServer_Intercept_Message() - { - void Interceptor(MqttApplicationMessageInterceptorContext context) - { - context.ApplicationMessage.Payload = Encoding.ASCII.GetBytes("extended"); - } - - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - try - { - var options = new MqttServerOptions { ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(c => Interceptor(c)) }; - - await s.StartAsync(options); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - var c2 = await serverAdapter.ConnectTestClient("c2"); - await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("test").Build()); - - var isIntercepted = false; - c2.ApplicationMessageReceived += (sender, args) => - { - isIntercepted = string.Compare("extended", Encoding.UTF8.GetString(args.ApplicationMessage.Payload), StringComparison.Ordinal) == 0; - }; - - await c1.PublishAsync(builder => builder.WithTopic("test")); - await c1.DisconnectAsync(); - - await Task.Delay(500); - - Assert.IsTrue(isIntercepted); - } - finally - { - await s.StopAsync(); - } - } - - [TestMethod] - public async Task MqttServer_Body() - { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var bodyIsMatching = false; - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = await serverAdapter.ConnectTestClient("c1"); - var c2 = await serverAdapter.ConnectTestClient("c2"); - - c1.ApplicationMessageReceived += (_, e) => - { - if (Encoding.UTF8.GetString(e.ApplicationMessage.Payload) == "The body") - { - bodyIsMatching = true; - } - }; - - await c1.SubscribeAsync("A", MqttQualityOfServiceLevel.AtMostOnce); - await c2.PublishAsync(builder => builder.WithTopic("A").WithPayload(Encoding.UTF8.GetBytes("The body"))); - - await Task.Delay(1000); - } - finally - { - await s.StopAsync(); - } - - Assert.IsTrue(bodyIsMatching); - } - - [TestMethod] - public async Task MqttServer_Connection_Denied() - { - var server = new MqttFactory().CreateMqttServer(); - var client = new MqttFactory().CreateMqttClient(); - - try - { - var options = new MqttServerOptionsBuilder().WithConnectionValidator(context => - { - context.ReturnCode = MqttConnectReturnCode.ConnectionRefusedNotAuthorized; - }).Build(); - - await server.StartAsync(options); - - - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost").Build(); - - try - { - await client.ConnectAsync(clientOptions); - Assert.Fail("An exception should be raised."); - } - catch (Exception exception) - { - if (exception is MqttConnectingFailedException) - { - - } - else - { - Assert.Fail("Wrong exception."); - } - } - } - finally - { - await client.DisconnectAsync(); - await server.StopAsync(); - - client.Dispose(); - } - } - - [TestMethod] - public async Task MqttServer_Same_Client_Id_Connect_Disconnect_Event_Order() - { - using (var testSetup = new TestSetup()) - { - var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - - var events = new List(); - - server.ClientConnected += (_, __) => - { - lock (events) - { - events.Add("c"); - } - }; - - server.ClientDisconnected += (_, __) => - { - lock (events) - { - events.Add("d"); - } - }; - - var clientOptions = new MqttClientOptionsBuilder() - .WithClientId("same_id"); - - // c - var c1 = await testSetup.ConnectClientAsync(clientOptions); - - await Task.Delay(500); - - var flow = string.Join(string.Empty, events); - Assert.AreEqual("c", flow); - - // dc - var c2 = await testSetup.ConnectClientAsync(clientOptions); - - await Task.Delay(500); - - flow = string.Join(string.Empty, events); - Assert.AreEqual("cdc", flow); - - // nothing - await c1.DisconnectAsync(); - - await Task.Delay(500); - - // d - await c2.DisconnectAsync(); - - await Task.Delay(500); - - await server.StopAsync(); - - flow = string.Join(string.Empty, events); - Assert.AreEqual("cdcd", flow); - } - } - - [TestMethod] - public async Task MqttServer_Remove_Session() - { - using (var testSetup = new TestSetup()) - { - var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - - var clientOptions = new MqttClientOptionsBuilder(); - var c1 = await testSetup.ConnectClientAsync(clientOptions); - await Task.Delay(500); - Assert.AreEqual(1, (await server.GetClientSessionsStatusAsync()).Count); - - await c1.DisconnectAsync(); - await Task.Delay(500); - - Assert.AreEqual(0, (await server.GetClientSessionsStatusAsync()).Count); - } - } - - [TestMethod] - public async Task MqttServer_Stop_And_Restart() - { - var server = new MqttFactory().CreateMqttServer(); - await server.StartAsync(new MqttServerOptions()); - - var client = new MqttFactory().CreateMqttClient(); - await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()); - await server.StopAsync(); - - try - { - var client2 = new MqttFactory().CreateMqttClient(); - await client2.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()); - - Assert.Fail("Connecting should fail."); - } - catch (Exception) - { - } - - await server.StartAsync(new MqttServerOptions()); - var client3 = new MqttFactory().CreateMqttClient(); - await client3.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()); - - await server.StopAsync(); - } - - [TestMethod] - public async Task MqttServer_Close_Idle_Connection() - { - var server = new MqttFactory().CreateMqttServer(); - - try - { - await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build()); - - var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await client.ConnectAsync("localhost", 1883); - - // Don't send anything. The server should close the connection. - await Task.Delay(TimeSpan.FromSeconds(5)); - - try - { - await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); - Assert.Fail("Receive should throw an exception."); - } - catch (SocketException) - { - } - } - finally - { - await server.StopAsync(); - } - } - - [TestMethod] - public async Task MqttServer_Send_Garbage() - { - var server = new MqttFactory().CreateMqttServer(); - - try - { - await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build()); - - var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await client.ConnectAsync("localhost", 1883); - await client.SendAsync(Encoding.UTF8.GetBytes("Garbage"), SocketFlags.None); - - await Task.Delay(TimeSpan.FromSeconds(5)); - - try - { - await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); - Assert.Fail("Receive should throw an exception."); - } - catch (SocketException) - { - } - } - finally - { - await server.StopAsync(); - } - } - - private class TestStorage : IMqttServerStorage - { - public IList Messages = new List(); - - public Task SaveRetainedMessagesAsync(IList messages) - { - Messages = messages; - return Task.CompletedTask; - } - - public Task> LoadRetainedMessagesAsync() - { - return Task.FromResult(Messages); - } - } - - private static async Task TestPublishAsync( - string topic, - MqttQualityOfServiceLevel qualityOfServiceLevel, - string topicFilter, - MqttQualityOfServiceLevel filterQualityOfServiceLevel, - int expectedReceivedMessagesCount) - { - //using (var testSetup = new TestSetup()) - //{ - // var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - - // var clientOptions = new MqttClientOptionsBuilder(); - // var c1 = await testSetup.ConnectClientAsync(clientOptions); - // await Task.Delay(500); - // Assert.AreEqual(1, (await server.GetClientSessionsStatusAsync()).Count); - - // await c1.DisconnectAsync(); - // await Task.Delay(500); - - // Assert.AreEqual(0, (await server.GetClientSessionsStatusAsync()).Count); - //} - - - - var s = new MqttFactory().CreateMqttServer(); - - var receivedMessagesCount = 0; - try - { - await s.StartAsync(new MqttServerOptions()); - - var c1 = new MqttFactory().CreateMqttClient(); - c1.UseReceivedApplicationMessageHandler(c => receivedMessagesCount++); - - await c1.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("broker.hivemq.com").Build()); - await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic(topicFilter).WithQualityOfServiceLevel(filterQualityOfServiceLevel).Build()); - - var c2 = new MqttFactory().CreateMqttClient(); - await c2.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("broker.hivemq.com").Build()); - await c2.PublishAsync(builder => builder.WithTopic(topic).WithPayload(new byte[0]).WithQualityOfServiceLevel(qualityOfServiceLevel)); - await c2.DisconnectAsync().ConfigureAwait(false); - - await Task.Delay(500); - await c1.UnsubscribeAsync(topicFilter); - await Task.Delay(500); - } - finally - { - await s.StopAsync(); - } - - Assert.AreEqual(expectedReceivedMessagesCount, receivedMessagesCount); - } - } -} diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index 7b7e959..8bfae0b 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServerEventDispatcher()); + var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); @@ -26,7 +26,7 @@ namespace MQTTnet.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServerEventDispatcher()); + var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); @@ -41,7 +41,7 @@ namespace MQTTnet.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServerEventDispatcher()); + var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); @@ -57,7 +57,7 @@ namespace MQTTnet.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServerEventDispatcher()); + var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); @@ -70,7 +70,7 @@ namespace MQTTnet.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServerEventDispatcher()); + var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); @@ -81,7 +81,7 @@ namespace MQTTnet.Tests var up = new MqttUnsubscribePacket(); up.TopicFilters.Add("A/B/C"); - sm.Unsubscribe(up); + sm.UnsubscribeAsync(up); Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } diff --git a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs new file mode 100644 index 0000000..4822bea --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs @@ -0,0 +1,159 @@ +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Threading.Tasks; +using MQTTnet.Client.Options; +using MQTTnet.Tests.Mockups; +using MQTTnet.Client; +using MQTTnet.Protocol; +using MQTTnet.Server; + +namespace MQTTnet.Tests +{ + [TestClass] + public class Server_Status_Tests + { + [TestMethod] + public async Task Show_Client_And_Session_Statistics() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); + var c2 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client2")); + + await Task.Delay(500); + + var clientStatus = await server.GetClientStatusAsync(); + var sessionStatus = await server.GetSessionStatusAsync(); + + Assert.AreEqual(2, clientStatus.Count); + Assert.AreEqual(2, sessionStatus.Count); + + Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); + Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client2")); + + await c1.DisconnectAsync(); + await c2.DisconnectAsync(); + + await Task.Delay(500); + + clientStatus = await server.GetClientStatusAsync(); + sessionStatus = await server.GetSessionStatusAsync(); + + Assert.AreEqual(0, clientStatus.Count); + Assert.AreEqual(0, sessionStatus.Count); + } + } + + [TestMethod] + public async Task Disconnect_Client() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); + + await Task.Delay(500); + + var clientStatus = await server.GetClientStatusAsync(); + + Assert.AreEqual(1, clientStatus.Count); + + Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); + + await clientStatus.First().DisconnectAsync(); + + await Task.Delay(500); + + Assert.IsFalse(c1.IsConnected); + + clientStatus = await server.GetClientStatusAsync(); + + Assert.AreEqual(0, clientStatus.Count); + } + } + + [TestMethod] + public async Task Keep_Persistent_Session() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); + var c2 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client2")); + + await c1.DisconnectAsync(); + + await Task.Delay(500); + + var clientStatus = await server.GetClientStatusAsync(); + var sessionStatus = await server.GetSessionStatusAsync(); + + Assert.AreEqual(1, clientStatus.Count); + Assert.AreEqual(2, sessionStatus.Count); + + await c2.DisconnectAsync(); + + await Task.Delay(500); + + clientStatus = await server.GetClientStatusAsync(); + sessionStatus = await server.GetSessionStatusAsync(); + + Assert.AreEqual(0, clientStatus.Count); + Assert.AreEqual(2, sessionStatus.Count); + } + } + + [TestMethod] + public async Task Track_Sent_Application_Messages() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); + + var c1 = await testEnvironment.ConnectClientAsync(); + + for (var i = 1; i < 25; i++) + { + await c1.PublishAsync("a"); + await Task.Delay(50); + + var clientStatus = await server.GetClientStatusAsync(); + Assert.AreEqual(i, clientStatus.First().SentApplicationMessagesCount); + Assert.AreEqual(0, clientStatus.First().ReceivedApplicationMessagesCount); + } + } + } + + [TestMethod] + public async Task Track_Sent_Packets() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithNoKeepAlive()); + + for (var i = 1; i < 25; i++) + { + // At most once will send one packet to the client and the server will reply + // with an additional ACK packet. + await c1.PublishAsync("a", string.Empty, MqttQualityOfServiceLevel.AtLeastOnce); + await Task.Delay(50); + + var clientStatus = await server.GetClientStatusAsync(); + + Assert.AreEqual(i, clientStatus.First().SentApplicationMessagesCount, "SAMC invalid!"); + Assert.AreEqual(i, clientStatus.First().SentPacketsCount, "SPC invalid!"); + + // +1 because ConnACK package is already counted. + Assert.AreEqual(i + 1, clientStatus.First().ReceivedPacketsCount, "RPC invalid!"); + } + } + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs new file mode 100644 index 0000000..b9efc7a --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -0,0 +1,970 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Adapter; +using MQTTnet.Client; +using MQTTnet.Client.Connecting; +using MQTTnet.Client.Options; +using MQTTnet.Protocol; +using MQTTnet.Server; +using MQTTnet.Tests.Mockups; + +namespace MQTTnet.Tests +{ + [TestClass] + public class Server_Tests + { + [TestMethod] + public async Task Publish_At_Most_Once_0x00() + { + await TestPublishAsync( + "A/B/C", + MqttQualityOfServiceLevel.AtMostOnce, + "A/B/C", + MqttQualityOfServiceLevel.AtMostOnce, + 1); + } + + [TestMethod] + public async Task Publish_At_Least_Once_0x01() + { + await TestPublishAsync( + "A/B/C", + MqttQualityOfServiceLevel.AtLeastOnce, + "A/B/C", + MqttQualityOfServiceLevel.AtLeastOnce, + 1); + } + + [TestMethod] + public async Task Publish_Exactly_Once_0x02() + { + await TestPublishAsync( + "A/B/C", + MqttQualityOfServiceLevel.ExactlyOnce, + "A/B/C", + MqttQualityOfServiceLevel.ExactlyOnce, + 1); + } + + [TestMethod] + public async Task Will_Message_Do_Not_Send() + { + using (var testEnvironment = new TestEnvironment()) + { + var receivedMessagesCount = 0; + + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + + var willMessage = new MqttApplicationMessageBuilder().WithTopic("My/last/will").WithAtMostOnceQoS().Build(); + + var clientOptions = new MqttClientOptionsBuilder().WithWillMessage(willMessage); + + var c1 = await testEnvironment.ConnectClientAsync(); + c1.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic("#").Build()); + + var c2 = await testEnvironment.ConnectClientAsync(clientOptions); + await c2.DisconnectAsync().ConfigureAwait(false); + + await Task.Delay(1000); + + Assert.AreEqual(0, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Will_Message_Send() + { + using (var testEnvironment = new TestEnvironment()) + { + var receivedMessagesCount = 0; + + await testEnvironment.StartServerAsync(); + + var willMessage = new MqttApplicationMessageBuilder().WithTopic("My/last/will").WithAtMostOnceQoS().Build(); + + var clientOptions = new MqttClientOptionsBuilder().WithWillMessage(willMessage); + + var c1 = await testEnvironment.ConnectClientAsync(); + c1.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic("#").Build()); + + var c2 = await testEnvironment.ConnectClientAsync(clientOptions); + c2.Dispose(); // Dispose will not send a DISCONNECT pattern first so the will message must be sent. + + await Task.Delay(1000); + + Assert.AreEqual(1, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Subscribe_Unsubscribe() + { + using (var testEnvironment = new TestEnvironment()) + { + var receivedMessagesCount = 0; + + var server = await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("c1")); + c1.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + + var c2 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("c2")); + + var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); + await c2.PublishAsync(message); + + await Task.Delay(500); + Assert.AreEqual(0, receivedMessagesCount); + + var subscribeEventCalled = false; + server.ClientSubscribedTopic += (_, e) => + { + subscribeEventCalled = e.TopicFilter.Topic == "a" && e.ClientId == "c1"; + }; + + await c1.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); + await Task.Delay(250); + Assert.IsTrue(subscribeEventCalled, "Subscribe event not called."); + + await c2.PublishAsync(message); + await Task.Delay(250); + Assert.AreEqual(1, receivedMessagesCount); + + var unsubscribeEventCalled = false; + server.ClientUnsubscribedTopic += (_, e) => + { + unsubscribeEventCalled = e.TopicFilter == "a" && e.ClientId == "c1"; + }; + + await c1.UnsubscribeAsync("a"); + await Task.Delay(250); + Assert.IsTrue(unsubscribeEventCalled, "Unsubscribe event not called."); + + await c2.PublishAsync(message); + await Task.Delay(500); + Assert.AreEqual(1, receivedMessagesCount); + + await Task.Delay(500); + + Assert.AreEqual(1, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Publish_From_Server() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + + var receivedMessagesCount = 0; + + var client = await testEnvironment.ConnectClientAsync(); + client.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + + var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); + await client.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); + + await server.PublishAsync(message); + + await Task.Delay(1000); + + Assert.AreEqual(1, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Publish_Multiple_Clients() + { + var receivedMessagesCount = 0; + var locked = new object(); + + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(); + var c2 = await testEnvironment.ConnectClientAsync(); + + c1.UseReceivedApplicationMessageHandler(c => + { + lock (locked) + { + receivedMessagesCount++; + } + }); + + c2.UseReceivedApplicationMessageHandler(c => + { + lock (locked) + { + receivedMessagesCount++; + } + }); + + await c1.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); + await c2.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); + + var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); + + for (var i = 0; i < 1000; i++) + { + await c1.PublishAsync(message); + } + + await Task.Delay(500); + + Assert.AreEqual(2000, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Session_Takeover() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); + + var options = new MqttClientOptionsBuilder() + .WithCleanSession(false) + .WithClientId("a"); + + var client1 = await testEnvironment.ConnectClientAsync(options); + await Task.Delay(500); + + var client2 = await testEnvironment.ConnectClientAsync(options); + await Task.Delay(500); + + Assert.IsFalse(client1.IsConnected); + Assert.IsTrue(client2.IsConnected); + } + } + + [TestMethod] + public async Task No_Messages_If_No_Subscription() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); + + var client = await testEnvironment.ConnectClientAsync(); + var receivedMessages = new List(); + + client.Connected += async (s, e) => + { + await client.PublishAsync("Connected"); + }; + + client.UseReceivedApplicationMessageHandler(c => + { + lock (receivedMessages) + { + receivedMessages.Add(c.ApplicationMessage); + } + }); + + await Task.Delay(500); + + await client.PublishAsync("Hello"); + + await Task.Delay(500); + + Assert.AreEqual(0, receivedMessages.Count); + } + } + + [TestMethod] + public async Task Set_Subscription_At_Server() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + server.ClientConnected += async (s, e) => + { + await server.SubscribeAsync(e.ClientId, "topic1"); + }; + + var client = await testEnvironment.ConnectClientAsync(); + var receivedMessages = new List(); + + client.UseReceivedApplicationMessageHandler(c => + { + lock (receivedMessages) + { + receivedMessages.Add(c.ApplicationMessage); + } + }); + + await Task.Delay(500); + + await client.PublishAsync("Hello"); + await Task.Delay(100); + Assert.AreEqual(0, receivedMessages.Count); + + await client.PublishAsync("topic1"); + await Task.Delay(100); + Assert.AreEqual(1, receivedMessages.Count); + } + } + + [TestMethod] + public async Task Shutdown_Disconnects_Clients_Gracefully() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + + var disconnectCalled = 0; + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder()); + c1.Disconnected += (sender, args) => disconnectCalled++; + + await Task.Delay(100); + + await server.StopAsync(); + + await Task.Delay(100); + + Assert.AreEqual(1, disconnectCalled); + } + } + + [TestMethod] + public async Task Handle_Clean_Disconnect() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + + var clientConnectedCalled = 0; + var clientDisconnectedCalled = 0; + + server.ClientConnected += (_, __) => Interlocked.Increment(ref clientConnectedCalled); + server.ClientDisconnected += (_, __) => Interlocked.Increment(ref clientDisconnectedCalled); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder()); + + Assert.AreEqual(1, clientConnectedCalled); + Assert.AreEqual(0, clientDisconnectedCalled); + + await Task.Delay(500); + + await c1.DisconnectAsync(); + + await Task.Delay(500); + + Assert.AreEqual(1, clientConnectedCalled); + Assert.AreEqual(1, clientDisconnectedCalled); + } + } + + [TestMethod] + public async Task Client_Disconnect_Without_Errors() + { + using (var testEnvironment = new TestEnvironment()) + { + bool clientWasConnected; + + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + try + { + var client = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder()); + + clientWasConnected = true; + + await client.DisconnectAsync(); + + await Task.Delay(500); + } + finally + { + await server.StopAsync(); + } + + Assert.IsTrue(clientWasConnected); + + testEnvironment.ThrowIfLogErrors(); + } + } + + [TestMethod] + public async Task Handle_Lots_Of_Parallel_Retained_Messages() + { + const int ClientCount = 50; + + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + + var tasks = new List(); + for (var i = 0; i < ClientCount; i++) + { + var i2 = i; + var testEnvironment2 = testEnvironment; + + tasks.Add(Task.Run(async () => + { + try + { + using (var client = await testEnvironment2.ConnectClientAsync()) + { + // Clear retained message. + await client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + i2) + .WithPayload(new byte[0]).WithRetainFlag().Build()); + + // Set retained message. + await client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + i2) + .WithPayload("value").WithRetainFlag().Build()); + + await client.DisconnectAsync(); + } + } + catch (Exception exception) + { + testEnvironment2.TrackException(exception); + } + })); + } + + await Task.WhenAll(tasks); + + await Task.Delay(1000); + + var retainedMessages = await server.GetRetainedMessagesAsync(); + + Assert.AreEqual(ClientCount, retainedMessages.Count); + + for (var i = 0; i < ClientCount; i++) + { + Assert.IsTrue(retainedMessages.Any(m => m.Topic == "r" + i)); + } + } + } + + [TestMethod] + public async Task Retained_Messages_Flow() + { + using (var testEnvironment = new TestEnvironment()) + { + var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); + + await testEnvironment.StartServerAsync(); + var c1 = await testEnvironment.ConnectClientAsync(); + + var receivedMessages = 0; + + var c2 = await testEnvironment.ConnectClientAsync(); + c2.UseReceivedApplicationMessageHandler(c => + { + Interlocked.Increment(ref receivedMessages); + }); + + await c1.PublishAsync(retainedMessage); + await c1.DisconnectAsync(); + await Task.Delay(500); + + for (var i = 0; i < 5; i++) + { + await c2.UnsubscribeAsync("r"); + await Task.Delay(100); + Assert.AreEqual(i, receivedMessages); + + await c2.SubscribeAsync("r"); + await Task.Delay(100); + Assert.AreEqual(i + 1, receivedMessages); + } + + await c2.DisconnectAsync(); + } + } + + [TestMethod] + public async Task Receive_No_Retained_Message_After_Subscribe() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(); + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); + await c1.DisconnectAsync(); + + var receivedMessagesCount = 0; + + var c2 = await testEnvironment.ConnectClientAsync(); + c2.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained_other").Build()); + + await Task.Delay(500); + + Assert.AreEqual(0, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Receive_Retained_Message_After_Subscribe() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(); + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); + await c1.DisconnectAsync(); + + var receivedMessages = new List(); + + var c2 = await testEnvironment.ConnectClientAsync(); + c2.UseReceivedApplicationMessageHandler(c => + { + lock (receivedMessages) + { + receivedMessages.Add(c.ApplicationMessage); + } + }); + + await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); + + await Task.Delay(500); + + Assert.AreEqual(1, receivedMessages.Count); + Assert.IsTrue(receivedMessages.First().Retain); + } + } + + [TestMethod] + public async Task Clear_Retained_Message() + { + using (var testEnvironment = new TestEnvironment()) + { + var receivedMessagesCount = 0; + + await testEnvironment.StartServerAsync(); + + var c1 = await testEnvironment.ConnectClientAsync(); + + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[0]).WithRetainFlag().Build()); + + await c1.DisconnectAsync(); + + var c2 = await testEnvironment.ConnectClientAsync(); + + c2.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + + await Task.Delay(200); + await c2.SubscribeAsync(new TopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); + await Task.Delay(500); + + Assert.AreEqual(0, receivedMessagesCount); + } + } + + [TestMethod] + public async Task Persist_Retained_Message() + { + var serverStorage = new TestServerStorage(); + + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithStorage(serverStorage)); + + var c1 = await testEnvironment.ConnectClientAsync(); + + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); + + await Task.Delay(500); + + Assert.AreEqual(1, serverStorage.Messages.Count); + } + } + + [TestMethod] + public async Task Intercept_Message() + { + void Interceptor(MqttApplicationMessageInterceptorContext context) + { + context.ApplicationMessage.Payload = Encoding.ASCII.GetBytes("extended"); + } + + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithApplicationMessageInterceptor(Interceptor)); + + var c1 = await testEnvironment.ConnectClientAsync(); + var c2 = await testEnvironment.ConnectClientAsync(); + await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("test").Build()); + + var isIntercepted = false; + c2.UseReceivedApplicationMessageHandler(c => + { + isIntercepted = string.Compare("extended", Encoding.UTF8.GetString(c.ApplicationMessage.Payload), StringComparison.Ordinal) == 0; + }); + + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("test").Build()); + await c1.DisconnectAsync(); + + await Task.Delay(500); + + Assert.IsTrue(isIntercepted); + } + } + + [TestMethod] + public async Task Send_Long_Body() + { + using (var testEnvironment = new TestEnvironment()) + { + const int PayloadSizeInMB = 30; + const int CharCount = PayloadSizeInMB * 1024 * 1024; + + var longBody = new byte[CharCount]; + byte @char = 32; + + for (long i = 0; i < PayloadSizeInMB * 1024L * 1024L; i++) + { + longBody[i] = @char; + + @char++; + + if (@char > 126) + { + @char = 32; + } + } + + byte[] receivedBody = null; + + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + var client1 = await testEnvironment.ConnectClientAsync(); + client1.UseReceivedApplicationMessageHandler(c => + { + receivedBody = c.ApplicationMessage.Payload; + }); + + await client1.SubscribeAsync("string"); + + var client2 = await testEnvironment.ConnectClientAsync(); + await client2.PublishAsync("string", longBody); + + await Task.Delay(500); + + Assert.IsTrue(longBody.SequenceEqual(receivedBody ?? new byte[0])); + } + } + + [TestMethod] + public async Task Deny_Connection() + { + var serverOptions = new MqttServerOptionsBuilder().WithConnectionValidator(context => + { + context.ReturnCode = MqttConnectReturnCode.ConnectionRefusedNotAuthorized; + }); + + using (var testEnvironment = new TestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + await testEnvironment.StartServerAsync(serverOptions); + + try + { + await testEnvironment.ConnectClientAsync(); + Assert.Fail("An exception should be raised."); + } + catch (Exception exception) + { + if (exception is MqttConnectingFailedException connectingFailedException) + { + Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode); + } + else + { + Assert.Fail("Wrong exception."); + } + } + } + } + + [TestMethod] + public async Task Same_Client_Id_Connect_Disconnect_Event_Order() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + + var events = new List(); + + server.ClientConnected += (_, __) => + { + lock (events) + { + events.Add("c"); + } + }; + + server.ClientDisconnected += (_, __) => + { + lock (events) + { + events.Add("d"); + } + }; + + var clientOptions = new MqttClientOptionsBuilder() + .WithClientId("same_id"); + + // c + var c1 = await testEnvironment.ConnectClientAsync(clientOptions); + + await Task.Delay(500); + + var flow = string.Join(string.Empty, events); + Assert.AreEqual("c", flow); + + // dc + var c2 = await testEnvironment.ConnectClientAsync(clientOptions); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdc", flow); + + // nothing + await c1.DisconnectAsync(); + + await Task.Delay(500); + + // d + await c2.DisconnectAsync(); + + await Task.Delay(500); + + await server.StopAsync(); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdcd", flow); + } + } + + [TestMethod] + public async Task Remove_Session() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + + var clientOptions = new MqttClientOptionsBuilder(); + var c1 = await testEnvironment.ConnectClientAsync(clientOptions); + await Task.Delay(500); + Assert.AreEqual(1, (await server.GetClientStatusAsync()).Count); + + await c1.DisconnectAsync(); + await Task.Delay(500); + + Assert.AreEqual(0, (await server.GetClientStatusAsync()).Count); + } + } + + [TestMethod] + public async Task Stop_And_Restart() + { + using (var testEnvironment = new TestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + var server = await testEnvironment.StartServerAsync(); + + await testEnvironment.ConnectClientAsync(); + await server.StopAsync(); + + try + { + await testEnvironment.ConnectClientAsync(); + Assert.Fail("Connecting should fail."); + } + catch (Exception) + { + } + + await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultEndpointPort(testEnvironment.ServerPort).Build()); + await testEnvironment.ConnectClientAsync(); + } + } + + [TestMethod] + public async Task Close_Idle_Connection() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); + + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync("localhost", testEnvironment.ServerPort); + + // Don't send anything. The server should close the connection. + await Task.Delay(TimeSpan.FromSeconds(3)); + + try + { + var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + if (receivedBytes == 0) + { + return; + } + + Assert.Fail("Receive should throw an exception."); + } + catch (SocketException) + { + } + } + } + + [TestMethod] + public async Task Send_Garbage() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); + + // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state + // forever. This is security related. + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync("localhost", testEnvironment.ServerPort); + await client.SendAsync(Encoding.UTF8.GetBytes("Garbage"), SocketFlags.None); + + await Task.Delay(TimeSpan.FromSeconds(3)); + + try + { + var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + if (receivedBytes == 0) + { + return; + } + + Assert.Fail("Receive should throw an exception."); + } + catch (SocketException) + { + } + } + } + + [TestMethod] + public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithSubscriptionInterceptor(c => + { + // This should lead to no subscriptions for "n" at all. So also no sending of retained messages. + if (c.TopicFilter.Topic == "n") + { + c.AcceptSubscription = false; + } + })); + + // Prepare some retained messages. + var client1 = await testEnvironment.ConnectClientAsync(); + await client1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("y").WithPayload("x").WithRetainFlag().Build()); + await client1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("n").WithPayload("x").WithRetainFlag().Build()); + await client1.DisconnectAsync(); + + await Task.Delay(500); + + // Subscribe to all retained message types. + // It is important to do this in a range of filters to ensure that a subscription is not "hidden". + var client2 = await testEnvironment.ConnectClientAsync(); + + var buffer = new StringBuilder(); + + client2.UseReceivedApplicationMessageHandler(c => + { + lock (buffer) + { + buffer.Append(c.ApplicationMessage.Topic); + } + }); + + await client2.SubscribeAsync(new TopicFilter { Topic = "y" }, new TopicFilter { Topic = "n" }); + + await Task.Delay(500); + + Assert.AreEqual("y", buffer.ToString()); + } + } + + [TestMethod] + public async Task Collect_Messages_In_Disconnected_Session() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); + + // Create the session including the subscription. + var client1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("a")); + await client1.SubscribeAsync("x"); + await client1.DisconnectAsync(); + await Task.Delay(500); + + var clientStatus = await server.GetClientStatusAsync(); + Assert.AreEqual(0, clientStatus.Count); + + var client2 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("b")); + await client2.PublishAsync("x", "1"); + await client2.PublishAsync("x", "2"); + await client2.PublishAsync("x", "3"); + await client2.DisconnectAsync(); + + await Task.Delay(500); + + clientStatus = await server.GetClientStatusAsync(); + var sessionStatus = await server.GetSessionStatusAsync(); + + Assert.AreEqual(0, clientStatus.Count); + Assert.AreEqual(2, sessionStatus.Count); + + Assert.AreEqual(3, sessionStatus.First(s => s.ClientId == "a").PendingApplicationMessagesCount); + } + } + + private static async Task TestPublishAsync( + string topic, + MqttQualityOfServiceLevel qualityOfServiceLevel, + string topicFilter, + MqttQualityOfServiceLevel filterQualityOfServiceLevel, + int expectedReceivedMessagesCount) + { + using (var testEnvironment = new TestEnvironment()) + { + var receivedMessagesCount = 0; + + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); + + var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("receiver")); + c1.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); + await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic(topicFilter).WithQualityOfServiceLevel(filterQualityOfServiceLevel).Build()); + + var c2 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("sender")); + await c2.PublishAsync(new MqttApplicationMessageBuilder().WithTopic(topic).WithPayload(new byte[0]).WithQualityOfServiceLevel(qualityOfServiceLevel).Build()); + await c2.DisconnectAsync().ConfigureAwait(false); + + await Task.Delay(500); + await c1.UnsubscribeAsync(topicFilter); + await Task.Delay(500); + + Assert.AreEqual(expectedReceivedMessagesCount, receivedMessagesCount); + } + } + } +} diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml index bc1cdf4..210f059 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml @@ -5,6 +5,7 @@ xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:server="using:MQTTnet.Server" + xmlns:status="using:MQTTnet.Server.Status" d:DesignHeight="800" d:DesignWidth="800" mc:Ignorable="d"> @@ -228,7 +229,7 @@ - +