diff --git a/Build/MQTTnet.Extensions.Rpc.nuspec b/Build/MQTTnet.Extensions.Rpc.nuspec index cb31c21..8bca0b5 100644 --- a/Build/MQTTnet.Extensions.Rpc.nuspec +++ b/Build/MQTTnet.Extensions.Rpc.nuspec @@ -12,7 +12,7 @@ This is a extension library which allows executing synchronous device calls including a response using MQTTnet. * Updated to MQTTnet 2.7.5. - Copyright Christian Kratky 2016-2017 + Copyright Christian Kratky 2016-2018 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin @@ -21,19 +21,15 @@ - - - - @@ -41,7 +37,17 @@ + + + - + + + + + + + + \ No newline at end of file diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 4146cd3..2be9d25 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -2,7 +2,7 @@ MQTTnet - 2.7.5 + 2.8.0 Christian Kratky Christian Kratky https://github.com/chkr1011/MQTTnet/blob/master/LICENSE @@ -10,47 +10,36 @@ https://raw.githubusercontent.com/chkr1011/MQTTnet/master/Images/Logo_128x128.png false MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker). - * [Client] Fixed a deadlock while the client disconnects. -* [Client] Fixed broken support for protocol version 3.1.0. -* [Server] The _MqttTcpServerAdapter_ is now added to the ASP.NET services. -* [Server] _MqttServerAdapter_ is renamed to _MqttTcpServerAdapter_ (BREAKING CHANGE!). -* [Server] The server no longer sends the will message of a client if the disconnect was clean (via _Disconnect_ packet). -* [Server] The application message interceptor now allows closing the connection. -* [Server] Added a new flag for the _ClientDisconnected_ event which contains a value indicating whether the disconnect was clean (via _Disconnect_ packet). + * [Client] Received messages are now processed in the worker thread by default. Added a new setting for switching back to dedicated threads. +* [Server] Added support for other WebSocket sub protocol formats like mqttv-3.1.1 (thanks to @israellot). +* [Server] The takeover of an existing client sessions is now treated as a _clean_ disconnect of the previous client. Copyright Christian Kratky 2016-2018 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin - - - + - - - - - @@ -62,7 +51,6 @@ - - + \ No newline at end of file diff --git a/Build/build.ps1 b/Build/build.ps1 index 021ca1d..021858f 100644 --- a/Build/build.ps1 +++ b/Build/build.ps1 @@ -20,11 +20,11 @@ if ($path) { &$msbuild ..\Frameworks\MQTTnet.AspNetCore\MQTTnet.AspNetCore.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="netstandard2.0" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=".\..\..\Build\codeSigningKey.pfx" # Build the RPC extension - &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="net452" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m - &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="net461" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m - &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="netstandard1.3" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m - &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="netstandard2.0" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m - &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="uap10.0" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m + &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="net452" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=".\..\..\Build\codeSigningKey.pfx" + &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="net461" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=".\..\..\Build\codeSigningKey.pfx" + &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="netstandard1.3" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=".\..\..\Build\codeSigningKey.pfx" + &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="netstandard2.0" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=".\..\..\Build\codeSigningKey.pfx" + &$msbuild ..\Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj /t:Build /p:Configuration="Release" /p:TargetFramework="uap10.0" /p:FileVersion=$assemblyVersion /p:AssemblyVersion=$assemblyVersion /verbosity:m /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=".\..\..\Build\codeSigningKey.pfx" Remove-Item .\NuGet -Force -Recurse -ErrorAction SilentlyContinue diff --git a/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs b/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs index 938fcb4..e824279 100644 --- a/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs +++ b/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs @@ -1,18 +1,18 @@ using System; using System.Collections.Concurrent; +using System.Text; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Client; -using MQTTnet.Internal; +using MQTTnet.Exceptions; using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { public sealed class MqttRpcClient : IDisposable { - private const string ResponseTopic = "$MQTTnet.RPC/+/+/response"; private readonly ConcurrentDictionary> _waitingCalls = new ConcurrentDictionary>(); private readonly IMqttClient _mqttClient; - private bool _isEnabled; public MqttRpcClient(IMqttClient mqttClient) { @@ -21,19 +21,22 @@ namespace MQTTnet.Extensions.Rpc _mqttClient.ApplicationMessageReceived += OnApplicationMessageReceived; } - public async Task EnableAsync() + public Task ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) { - await _mqttClient.SubscribeAsync(new TopicFilterBuilder().WithTopic(ResponseTopic).WithAtLeastOnceQoS().Build()); - _isEnabled = true; + return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None); } - public async Task DisableAsync() + public Task ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) { - await _mqttClient.UnsubscribeAsync(ResponseTopic); - _isEnabled = false; + return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken); } - public async Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + public Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + { + return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None); + } + + public async Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) { if (methodName == null) throw new ArgumentNullException(nameof(methodName)); @@ -42,12 +45,7 @@ namespace MQTTnet.Extensions.Rpc throw new ArgumentException("The method name cannot contain /, + or #."); } - if (!_isEnabled) - { - throw new InvalidOperationException("The RPC client is not enabled."); - } - - var requestTopic = $"$MQTTnet.RPC/{Guid.NewGuid():N}/{methodName}"; + var requestTopic = $"MQTTnet.RPC/{Guid.NewGuid():N}/{methodName}"; var responseTopic = requestTopic + "/response"; var requestMessage = new MqttApplicationMessageBuilder() @@ -64,18 +62,49 @@ namespace MQTTnet.Extensions.Rpc throw new InvalidOperationException(); } - await _mqttClient.PublishAsync(requestMessage); - return await tcs.Task.TimeoutAfter(timeout); + await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false); + await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false); + + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token)) + { + linkedCts.Token.Register(() => + { + if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled) + { + tcs.TrySetCanceled(); + } + }); + + try + { + var result = await tcs.Task.ConfigureAwait(false); + timeoutCts.Cancel(false); + return result; + } + catch (TaskCanceledException taskCanceledException) + { + if (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + throw new MqttCommunicationTimedOutException(taskCanceledException); + } + else + { + throw; + } + } + } } finally { _waitingCalls.TryRemove(responseTopic, out _); + await _mqttClient.UnsubscribeAsync(responseTopic).ConfigureAwait(false); } } private void OnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs eventArgs) { - if (!_waitingCalls.TryRemove(eventArgs.ApplicationMessage.Topic, out TaskCompletionSource tcs)) + if (!_waitingCalls.TryRemove(eventArgs.ApplicationMessage.Topic, out var tcs)) { return; } diff --git a/Extensions/MQTTnet.Extensions.Rpc/SampleCCode.c b/Extensions/MQTTnet.Extensions.Rpc/SampleCCode.c index 5f28270..c20aa4a 100644 --- a/Extensions/MQTTnet.Extensions.Rpc/SampleCCode.c +++ b/Extensions/MQTTnet.Extensions.Rpc/SampleCCode.c @@ -1 +1,29 @@ - \ No newline at end of file +// If using the MQTT client PubSubClient it must be ensured that the request topic for each method is subscribed like the following. +_mqttClient.subscribe("MQTTnet.RPC/+/ping"); +_mqttClient.subscribe("MQTTnet.RPC/+/do_something"); + +// It is not allowed to change the structure of the topic. Otherwise RPC will not work. So method names can be separated using +// an _ or . but no +, # or . If it is required to distinguish between devices own rules can be defined like the following. +_mqttClient.subscribe("MQTTnet.RPC/+/deviceA.ping"); +_mqttClient.subscribe("MQTTnet.RPC/+/deviceB.ping"); +_mqttClient.subscribe("MQTTnet.RPC/+/deviceC.getTemperature"); + +// Within the callback of the MQTT client the topic must be checked if it belongs to MQTTnet RPC. The following code shows one +// possible way of doing this. +void mqtt_Callback(char *topic, byte *payload, unsigned int payloadLength) +{ + String topicString = String(topic); + + if (topicString.startsWith("MQTTnet.RPC/")) { + String responseTopic = topicString + String("/response"); + + if (topicString.endsWith("/deviceA.ping")) { + mqtt_publish(responseTopic, "pong", false); + return; + } + } +} + +// Important notes: +// ! Do not send response message with the _retain_ flag set to true. +// ! All required data for a RPC call and the result must be placed into the payload. \ No newline at end of file diff --git a/Frameworks/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Frameworks/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index 5a25712..87b2038 100644 --- a/Frameworks/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Frameworks/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -17,12 +17,12 @@ namespace MQTTnet.AspNetCore { string subprotocol = null; - if (context.Request.Headers.TryGetValue("Sec-WebSocket-Protocol", out var requestedSubProtocolValues) - && requestedSubProtocolValues.Count > 0 - && requestedSubProtocolValues.Any(v => v.ToLower() == "mqtt") - ) + if (context.Request.Headers.TryGetValue("Sec-WebSocket-Protocol", out var requestedSubProtocolValues)) { - subprotocol = "mqtt"; + // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. + subprotocol = requestedSubProtocolValues + .OrderByDescending(p => p.Length) + .FirstOrDefault(p => p.ToLower().StartsWith("mqtt")); } var adapter = app.ApplicationServices.GetRequiredService(); diff --git a/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs b/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs index c84b9be..babc59b 100644 --- a/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs +++ b/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs @@ -3,12 +3,13 @@ using System.Net.WebSockets; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; +using MQTTnet.Implementations; using MQTTnet.Serializer; using MQTTnet.Server; namespace MQTTnet.AspNetCore { - public sealed class MqttWebSocketServerAdapter : IMqttServerAdapter, IDisposable + public class MqttWebSocketServerAdapter : IMqttServerAdapter { public event EventHandler ClientAccepted; @@ -26,8 +27,7 @@ namespace MQTTnet.AspNetCore { if (webSocket == null) throw new ArgumentNullException(nameof(webSocket)); - var channel = new MqttWebSocketServerChannel(webSocket); - var clientAdapter = new MqttChannelAdapter(channel, new MqttPacketSerializer(), new MqttNetLogger()); + var clientAdapter = new MqttChannelAdapter(new MqttWebSocketChannel(webSocket), new MqttPacketSerializer(), new MqttNetLogger()); var eventArgs = new MqttServerAdapterClientAcceptedEventArgs(clientAdapter); ClientAccepted?.Invoke(this, eventArgs); diff --git a/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs b/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs deleted file mode 100644 index 7325a5f..0000000 --- a/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs +++ /dev/null @@ -1,60 +0,0 @@ -using System; -using System.IO; -using System.Net.WebSockets; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Channel; -using MQTTnet.Implementations; - -namespace MQTTnet.AspNetCore -{ - public class MqttWebSocketServerChannel : IMqttChannel, IDisposable - { - private WebSocket _webSocket; - - public MqttWebSocketServerChannel(WebSocket webSocket) - { - _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); - - SendStream = new WebSocketStream(_webSocket); - ReceiveStream = SendStream; - } - - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } - - public Task ConnectAsync() - { - return Task.CompletedTask; - } - - public async Task DisconnectAsync() - { - if (_webSocket == null) - { - return; - } - - try - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); - } - finally - { - Dispose(); - } - } - - public void Dispose() - { - SendStream?.Dispose(); - ReceiveStream?.Dispose(); - - _webSocket?.Dispose(); - - SendStream = null; - ReceiveStream = null; - _webSocket = null; - } - } -} \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs index 7a5c02a..0970aff 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs @@ -11,11 +11,11 @@ namespace MQTTnet.Adapter { IMqttPacketSerializer PacketSerializer { get; } - Task ConnectAsync(TimeSpan timeout); + Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken); - Task DisconnectAsync(TimeSpan timeout); + Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken); - Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets); + Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets, CancellationToken cancellationToken); Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); } diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttServerAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttServerAdapter.cs index bd583f5..eff9eab 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttServerAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttServerAdapter.cs @@ -4,7 +4,7 @@ using MQTTnet.Server; namespace MQTTnet.Adapter { - public interface IMqttServerAdapter + public interface IMqttServerAdapter : IDisposable { event EventHandler ClientAccepted; diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index c2f9ba4..c531add 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -8,7 +8,6 @@ using System.Threading.Tasks; using MQTTnet.Channel; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Serializer; @@ -19,11 +18,11 @@ namespace MQTTnet.Adapter private const uint ErrorOperationAborted = 0x800703E3; private const int ReadBufferSize = 4096; // TODO: Move buffer size to config - private bool _isDisposed; - private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly IMqttNetLogger _logger; private readonly IMqttChannel _channel; + private bool _isDisposed; + public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger) { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); @@ -33,75 +32,52 @@ namespace MQTTnet.Adapter public IMqttPacketSerializer PacketSerializer { get; } - public Task ConnectAsync(TimeSpan timeout) + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); _logger.Verbose("Connecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync().TimeoutAfter(timeout)); + return ExecuteAndWrapExceptionAsync(() => + Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); } - public Task DisconnectAsync(TimeSpan timeout) + public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); _logger.Verbose("Disconnecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); + return ExecuteAndWrapExceptionAsync(() => + Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, cancellationToken)); } - public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) + public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets, CancellationToken cancellationToken) { ThrowIfDisposed(); - return ExecuteAndWrapExceptionAsync(async () => + foreach (var packet in packets) { - await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - try + if (packet == null) { - foreach (var packet in packets) - { - if (cancellationToken.IsCancellationRequested) - { - return; - } - - if (packet == null) - { - continue; - } - - _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); - - var chunks = PacketSerializer.Serialize(packet); - foreach (var chunk in chunks) - { - if (cancellationToken.IsCancellationRequested) - { - return; - } - - await _channel.SendStream.WriteAsync(chunk.Array, chunk.Offset, chunk.Count, cancellationToken).ConfigureAwait(false); - } - } + continue; + } - if (cancellationToken.IsCancellationRequested) - { - return; - } + await SendPacketAsync(timeout, cancellationToken, packet).ConfigureAwait(false); + } + } - if (timeout > TimeSpan.Zero) - { - await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); - } - else - { - await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); - } - } - finally - { - _semaphore.Release(); - } + private Task SendPacketAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) + { + return ExecuteAndWrapExceptionAsync(() => + { + _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); + + var packetData = PacketSerializer.Serialize(packet); + + return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( + packetData.Array, + packetData.Offset, + packetData.Count, + ct), timeout, cancellationToken); }); } @@ -117,11 +93,11 @@ namespace MQTTnet.Adapter { if (timeout > TimeSpan.Zero) { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); + receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); } else { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); + receivedMqttPacket = await ReceiveAsync(_channel, cancellationToken).ConfigureAwait(false); } if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) @@ -146,22 +122,22 @@ namespace MQTTnet.Adapter return packet; } - private static async Task ReceiveAsync(Stream stream, CancellationToken cancellationToken) + private static async Task ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) { - var header = await MqttPacketReader.ReadHeaderAsync(stream, cancellationToken).ConfigureAwait(false); + var header = await MqttPacketReader.ReadHeaderAsync(channel, cancellationToken).ConfigureAwait(false); if (header == null) { return null; } - + if (header.BodyLength == 0) { return new ReceivedMqttPacket(header, new MemoryStream(new byte[0], false)); } - var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(header.BodyLength) : new MemoryStream(); + var body = new MemoryStream(header.BodyLength); - var buffer = new byte[ReadBufferSize]; + var buffer = new byte[Math.Min(ReadBufferSize, header.BodyLength)]; while (body.Length < header.BodyLength) { var bytesLeft = header.BodyLength - (int)body.Length; @@ -170,7 +146,7 @@ namespace MQTTnet.Adapter bytesLeft = buffer.Length; } - var readBytesCount = await stream.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); + var readBytesCount = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); // Check if the client closed the connection before sending the full body. if (readBytesCount == 0) @@ -240,7 +216,7 @@ namespace MQTTnet.Adapter public void Dispose() { _isDisposed = true; - _semaphore?.Dispose(); + _channel?.Dispose(); } diff --git a/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs b/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs index b332884..dd9d1f2 100644 --- a/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs @@ -1,15 +1,15 @@ using System; -using System.IO; +using System.Threading; using System.Threading.Tasks; namespace MQTTnet.Channel { public interface IMqttChannel : IDisposable { - Stream SendStream { get; } - Stream ReceiveStream { get; } - - Task ConnectAsync(); + Task ConnectAsync(CancellationToken cancellationToken); Task DisconnectAsync(); + + Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); + Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); } } diff --git a/Frameworks/MQTTnet.NetStandard/Client/IMqttClientOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/IMqttClientOptions.cs index 647f688..aa873af 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/IMqttClientOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/IMqttClientOptions.cs @@ -6,17 +6,16 @@ namespace MQTTnet.Client public interface IMqttClientOptions { string ClientId { get; } - - IMqttClientCredentials Credentials { get; } bool CleanSession { get; } - MqttApplicationMessage WillMessage { get; } + IMqttClientCredentials Credentials { get; } + MqttProtocolVersion ProtocolVersion { get; } + IMqttClientChannelOptions ChannelOptions { get; } TimeSpan CommunicationTimeout { get; } TimeSpan KeepAlivePeriod { get; } TimeSpan? KeepAliveSendInterval { get; } + MqttReceivedApplicationMessageProcessingMode ReceivedApplicationMessageProcessingMode { get; } - MqttProtocolVersion ProtocolVersion { get; } - - IMqttClientChannelOptions ChannelOptions { get; } + MqttApplicationMessage WillMessage { get; } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs index 90790d5..edf9b44 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs @@ -18,12 +18,12 @@ namespace MQTTnet.Client private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly Stopwatch _sendTracker = new Stopwatch(); private readonly SemaphoreSlim _disconnectLock = new SemaphoreSlim(1, 1); + private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); + private readonly IMqttClientAdapterFactory _adapterFactory; - private readonly MqttPacketDispatcher _packetDispatcher; private readonly IMqttNetLogger _logger; private IMqttClientOptions _options; - private bool _isReceivingPackets; private CancellationTokenSource _cancellationTokenSource; private Task _packetReceiverTask; private Task _keepAliveMessageSenderTask; @@ -33,8 +33,6 @@ namespace MQTTnet.Client { _adapterFactory = channelFactory ?? throw new ArgumentNullException(nameof(channelFactory)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - - _packetDispatcher = new MqttPacketDispatcher(logger); } public event EventHandler Connected; @@ -52,27 +50,27 @@ namespace MQTTnet.Client try { - _options = options; _cancellationTokenSource = new CancellationTokenSource(); + _options = options; _packetIdentifierProvider.Reset(); _packetDispatcher.Reset(); _adapter = _adapterFactory.CreateClientAdapter(options, _logger); _logger.Verbose("Trying to connect with server."); - await _adapter.ConnectAsync(_options.CommunicationTimeout).ConfigureAwait(false); + await _adapter.ConnectAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false); _logger.Verbose("Connection with server established."); - await StartReceivingPacketsAsync().ConfigureAwait(false); + StartReceivingPackets(_cancellationTokenSource.Token); - var connectResponse = await AuthenticateAsync(options.WillMessage).ConfigureAwait(false); + var connectResponse = await AuthenticateAsync(options.WillMessage, _cancellationTokenSource.Token).ConfigureAwait(false); _logger.Verbose("MQTT connection with server established."); _sendTracker.Restart(); if (_options.KeepAlivePeriod != TimeSpan.Zero) { - StartSendingKeepAliveMessages(); + StartSendingKeepAliveMessages(_cancellationTokenSource.Token); } IsConnected = true; @@ -92,16 +90,11 @@ namespace MQTTnet.Client public async Task DisconnectAsync() { - if (!IsConnected) - { - return; - } - try { - if (!_cancellationTokenSource.IsCancellationRequested) + if (IsConnected && !_cancellationTokenSource.IsCancellationRequested) { - await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); + await SendAsync(new MqttDisconnectPacket(), _cancellationTokenSource.Token).ConfigureAwait(false); } } finally @@ -122,7 +115,7 @@ namespace MQTTnet.Client TopicFilters = topicFilters.ToList() }; - var response = await SendAndReceiveAsync(subscribePacket).ConfigureAwait(false); + var response = await SendAndReceiveAsync(subscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false); if (response.SubscribeReturnCodes.Count != subscribePacket.TopicFilters.Count) { @@ -144,7 +137,7 @@ namespace MQTTnet.Client TopicFilters = topicFilters.ToList() }; - await SendAndReceiveAsync(unsubscribePacket).ConfigureAwait(false); + await SendAndReceiveAsync(unsubscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false); } public async Task PublishAsync(IEnumerable applicationMessages) @@ -161,7 +154,7 @@ namespace MQTTnet.Client case MqttQualityOfServiceLevel.AtMostOnce: { // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] - await SendAsync(qosGroup.Cast().ToArray()).ConfigureAwait(false); + await SendAsync(qosGroup, _cancellationTokenSource.Token).ConfigureAwait(false); break; } case MqttQualityOfServiceLevel.AtLeastOnce: @@ -169,7 +162,7 @@ namespace MQTTnet.Client foreach (var publishPacket in qosGroup) { publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); - await SendAndReceiveAsync(publishPacket).ConfigureAwait(false); + await SendAndReceiveAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false); } break; @@ -180,13 +173,13 @@ namespace MQTTnet.Client { publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); - var pubRecPacket = await SendAndReceiveAsync(publishPacket).ConfigureAwait(false); + var pubRecPacket = await SendAndReceiveAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false); var pubRelPacket = new MqttPubRelPacket { PacketIdentifier = pubRecPacket.PacketIdentifier }; - await SendAndReceiveAsync(pubRelPacket).ConfigureAwait(false); + await SendAndReceiveAsync(pubRelPacket, _cancellationTokenSource.Token).ConfigureAwait(false); } break; @@ -207,7 +200,7 @@ namespace MQTTnet.Client _adapter?.Dispose(); } - private async Task AuthenticateAsync(MqttApplicationMessage willApplicationMessage) + private async Task AuthenticateAsync(MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken) { var connectPacket = new MqttConnectPacket { @@ -219,7 +212,7 @@ namespace MQTTnet.Client WillMessage = willApplicationMessage }; - var response = await SendAndReceiveAsync(connectPacket).ConfigureAwait(false); + var response = await SendAndReceiveAsync(connectPacket, cancellationToken).ConfigureAwait(false); if (response.ConnectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { throw new MqttConnectingFailedException(response.ConnectReturnCode); @@ -264,21 +257,19 @@ namespace MQTTnet.Client try { - if (_packetReceiverTask != null && _packetReceiverTask != sender) - { - _packetReceiverTask.Wait(); - } + await WaitForTaskAsync(_packetReceiverTask, sender).ConfigureAwait(false); + await WaitForTaskAsync(_keepAliveMessageSenderTask, sender).ConfigureAwait(false); if (_keepAliveMessageSenderTask != null && _keepAliveMessageSenderTask != sender) { - _keepAliveMessageSenderTask.Wait(); + await _keepAliveMessageSenderTask.ConfigureAwait(false); } if (_adapter != null) { - await _adapter.DisconnectAsync(_options.CommunicationTimeout).ConfigureAwait(false); + await _adapter.DisconnectAsync(_options.CommunicationTimeout, CancellationToken.None).ConfigureAwait(false); } - + _logger.Verbose("Disconnected from adapter."); } catch (Exception adapterException) @@ -297,121 +288,63 @@ namespace MQTTnet.Client } } - private async Task ProcessReceivedPacketAsync(MqttBasePacket packet) + private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - try - { - if (packet is MqttPublishPacket publishPacket) - { - await ProcessReceivedPublishPacketAsync(publishPacket).ConfigureAwait(false); - return; - } - - if (packet is MqttPingReqPacket) - { - await SendAsync(new MqttPingRespPacket()).ConfigureAwait(false); - return; - } - - if (packet is MqttDisconnectPacket) - { - await DisconnectAsync().ConfigureAwait(false); - return; - } - - if (packet is MqttPubRelPacket pubRelPacket) - { - await ProcessReceivedPubRelPacket(pubRelPacket).ConfigureAwait(false); - return; - } - - _packetDispatcher.Dispatch(packet); - } - catch (Exception exception) - { - _logger.Error(exception, "Unhandled exception while processing received packet."); - } + return SendAsync(new[] { packet }, cancellationToken); } - private void FireApplicationMessageReceivedEvent(MqttPublishPacket publishPacket) + private Task SendAsync(IEnumerable packets, CancellationToken cancellationToken) { - try - { - var applicationMessage = publishPacket.ToApplicationMessage(); - ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(_options.ClientId, applicationMessage)); - } - catch (Exception exception) + if (cancellationToken.IsCancellationRequested) { - _logger.Error(exception, "Unhandled exception while handling application message."); + throw new TaskCanceledException(); } + + _sendTracker.Restart(); + return _adapter.SendPacketsAsync(_options.CommunicationTimeout, packets, cancellationToken); } - private Task ProcessReceivedPublishPacketAsync(MqttPublishPacket publishPacket) + private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket { - if (_cancellationTokenSource.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { - return Task.FromResult(0); + throw new TaskCanceledException(); } - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) - { - FireApplicationMessageReceivedEvent(publishPacket); - return Task.FromResult(0); - } + _sendTracker.Restart(); - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + ushort identifier = 0; + if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier.HasValue) { - FireApplicationMessageReceivedEvent(publishPacket); - return SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + identifier = packetWithIdentifier.PacketIdentifier.Value; } - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) + var packetAwaiter = _packetDispatcher.AddPacketAwaiter(identifier); + try { - // QoS 2 is implement as method "B" [4.3.3 QoS 2: Exactly once delivery] - FireApplicationMessageReceivedEvent(publishPacket); - return SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); - } + await _adapter.SendPacketsAsync(_options.CommunicationTimeout, new[] { requestPacket }, cancellationToken).ConfigureAwait(false); + var respone = await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); - throw new MqttCommunicationException("Received a not supported QoS level."); - } - - private Task ProcessReceivedPubRelPacket(MqttPubRelPacket pubRelPacket) - { - var response = new MqttPubCompPacket + return (TResponsePacket)respone; + } + catch (MqttCommunicationTimedOutException) { - PacketIdentifier = pubRelPacket.PacketIdentifier - }; - - return SendAsync(response); - } - - private Task SendAsync(params MqttBasePacket[] packets) - { - _sendTracker.Restart(); - return _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, packets); - } - - private async Task SendAndReceiveAsync(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket - { - ushort? identifier = null; - if (requestPacket is IMqttPacketWithIdentifier requestPacketWithIdentifier) + _logger.Warning($"Timeout while waiting for packet of type '{typeof(TResponsePacket).Namespace}'."); + throw; + } + finally { - identifier = requestPacketWithIdentifier.PacketIdentifier; + _packetDispatcher.RemovePacketAwaiter(identifier); } - - var packetAwaiter = _packetDispatcher.WaitForPacketAsync(typeof(TResponsePacket), identifier, _options.CommunicationTimeout); - await SendAsync(requestPacket).ConfigureAwait(false); - - return (TResponsePacket)await packetAwaiter.ConfigureAwait(false); } - private async Task SendKeepAliveMessagesAsync() + private async Task SendKeepAliveMessagesAsync(CancellationToken cancellationToken) { _logger.Verbose("Start sending keep alive packets."); try { - while (!_cancellationTokenSource.Token.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { var keepAliveSendInterval = TimeSpan.FromSeconds(_options.KeepAlivePeriod.TotalSeconds * 0.75); if (_options.KeepAliveSendInterval.HasValue) @@ -421,10 +354,10 @@ namespace MQTTnet.Client if (_sendTracker.Elapsed > keepAliveSendInterval) { - await SendAndReceiveAsync(new MqttPingReqPacket()).ConfigureAwait(false); + await SendAndReceiveAsync(new MqttPingReqPacket(), cancellationToken).ConfigureAwait(false); } - await Task.Delay(keepAliveSendInterval, _cancellationTokenSource.Token).ConfigureAwait(false); + await Task.Delay(keepAliveSendInterval, cancellationToken).ConfigureAwait(false); } } catch (Exception exception) @@ -440,7 +373,7 @@ namespace MQTTnet.Client { _logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets."); } - + await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); } finally @@ -449,24 +382,34 @@ namespace MQTTnet.Client } } - private async Task ReceivePacketsAsync() + private async Task ReceivePacketsAsync(CancellationToken cancellationToken) { _logger.Verbose("Start receiving packets."); try { - while (!_cancellationTokenSource.Token.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - _isReceivingPackets = true; - - var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false); + var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); - if (_cancellationTokenSource.Token.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { return; } - StartProcessReceivedPacket(packet); + if (packet == null) + { + continue; + } + + if (_options.ReceivedApplicationMessageProcessingMode == MqttReceivedApplicationMessageProcessingMode.SingleThread) + { + await ProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false); + } + else if (_options.ReceivedApplicationMessageProcessingMode == MqttReceivedApplicationMessageProcessingMode.DedicatedThread) + { + StartProcessReceivedPacketAsync(packet, cancellationToken); + } } } catch (Exception exception) @@ -484,6 +427,7 @@ namespace MQTTnet.Client } await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); + _packetDispatcher.Dispatch(exception); } finally { @@ -491,26 +435,133 @@ namespace MQTTnet.Client } } - private void StartProcessReceivedPacket(MqttBasePacket packet) + private async Task ProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + { + try + { + if (packet is MqttPublishPacket publishPacket) + { + await ProcessReceivedPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false); + return; + } + + if (packet is MqttPingReqPacket) + { + await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false); + return; + } + + if (packet is MqttDisconnectPacket) + { + await DisconnectAsync().ConfigureAwait(false); + return; + } + + if (packet is MqttPubRelPacket pubRelPacket) + { + await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false); + return; + } + + _packetDispatcher.Dispatch(packet); + } + catch (Exception exception) + { + _logger.Error(exception, "Unhandled exception while processing received packet."); + } + } + + private Task ProcessReceivedPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + { + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) + { + FireApplicationMessageReceivedEvent(publishPacket); + return Task.FromResult(0); + } + + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + { + FireApplicationMessageReceivedEvent(publishPacket); + return SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }, cancellationToken); + } + + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) + { + // QoS 2 is implement as method "B" [4.3.3 QoS 2: Exactly once delivery] + FireApplicationMessageReceivedEvent(publishPacket); + return SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }, cancellationToken); + } + + throw new MqttCommunicationException("Received a not supported QoS level."); + } + + private Task ProcessReceivedPubRelPacket(MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { - Task.Run(() => ProcessReceivedPacketAsync(packet), _cancellationTokenSource.Token); + var response = new MqttPubCompPacket + { + PacketIdentifier = pubRelPacket.PacketIdentifier + }; + + return SendAsync(response, cancellationToken); } - private async Task StartReceivingPacketsAsync() + private void StartReceivingPackets(CancellationToken cancellationToken) { - _isReceivingPackets = false; + _packetReceiverTask = Task.Factory.StartNew( + () => ReceivePacketsAsync(cancellationToken), + cancellationToken, + TaskCreationOptions.LongRunning, + TaskScheduler.Current); + } - _packetReceiverTask = Task.Run(ReceivePacketsAsync, _cancellationTokenSource.Token); + private void StartSendingKeepAliveMessages(CancellationToken cancellationToken) + { + _keepAliveMessageSenderTask = Task.Factory.StartNew( + () => SendKeepAliveMessagesAsync(cancellationToken), + cancellationToken, + TaskCreationOptions.LongRunning, + TaskScheduler.Current); + } + + private void StartProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + Task.Run(() => ProcessReceivedPacketAsync(packet, cancellationToken), cancellationToken); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } - while (!_isReceivingPackets && !_cancellationTokenSource.Token.IsCancellationRequested) + private void FireApplicationMessageReceivedEvent(MqttPublishPacket publishPacket) + { + try { - await Task.Delay(TimeSpan.FromMilliseconds(100), _cancellationTokenSource.Token).ConfigureAwait(false); + var applicationMessage = publishPacket.ToApplicationMessage(); + ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(_options.ClientId, applicationMessage)); + } + catch (Exception exception) + { + _logger.Error(exception, "Unhandled exception while handling application message."); } } - private void StartSendingKeepAliveMessages() + private static async Task WaitForTaskAsync(Task task, Task sender) { - _keepAliveMessageSenderTask = Task.Run(SendKeepAliveMessagesAsync, _cancellationTokenSource.Token); + if (task == sender || task == null) + { + return; + } + + if (task.IsCanceled || task.IsCompleted || task.IsFaulted) + { + return; + } + + try + { + await task.ConfigureAwait(false); + } + catch (TaskCanceledException) + { + } } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptions.cs index 9640a47..ae0517a 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptions.cs @@ -5,22 +5,17 @@ namespace MQTTnet.Client { public class MqttClientOptions : IMqttClientOptions { - public MqttApplicationMessage WillMessage { get; set; } - public string ClientId { get; set; } = Guid.NewGuid().ToString("N"); - public bool CleanSession { get; set; } = true; - public IMqttClientCredentials Credentials { get; set; } = new MqttClientCredentials(); + public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; + public IMqttClientChannelOptions ChannelOptions { get; set; } + public TimeSpan CommunicationTimeout { get; set; } = TimeSpan.FromSeconds(10); public TimeSpan KeepAlivePeriod { get; set; } = TimeSpan.FromSeconds(15); - public TimeSpan? KeepAliveSendInterval { get; set; } + public MqttReceivedApplicationMessageProcessingMode ReceivedApplicationMessageProcessingMode { get; set; } = MqttReceivedApplicationMessageProcessingMode.SingleThread; - public TimeSpan CommunicationTimeout { get; set; } = TimeSpan.FromSeconds(10); - - public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; - - public IMqttClientChannelOptions ChannelOptions { get; set; } + public MqttApplicationMessage WillMessage { get; set; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptionsBuilder.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptionsBuilder.cs index ad766c3..2eefb81 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptionsBuilder.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientOptionsBuilder.cs @@ -36,6 +36,12 @@ namespace MQTTnet.Client return this; } + public MqttClientOptionsBuilder WithKeepAliveSendInterval(TimeSpan value) + { + _options.KeepAliveSendInterval = value; + return this; + } + public MqttClientOptionsBuilder WithClientId(string value) { _options.ClientId = value; @@ -108,6 +114,13 @@ namespace MQTTnet.Client return this; } + public MqttClientOptionsBuilder WithReceivedApplicationMessageProcessingMode( + MqttReceivedApplicationMessageProcessingMode mode) + { + _options.ReceivedApplicationMessageProcessingMode = mode; + return this; + } + public IMqttClientOptions Build() { if (_tlsOptions != null) diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs index ac96ce6..44bc5b4 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs @@ -6,7 +6,7 @@ public int? Port { get; set; } - public int BufferSize { get; set; } = 20 * 4096; + public int BufferSize { get; set; } = 4096; public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions(); } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs index 114283d..9718298 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs @@ -13,6 +13,8 @@ namespace MQTTnet.Client public CookieContainer CookieContainer { get; set; } + public int BufferSize { get; set; } = 4096; + public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions(); } } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs index 8b09fbc..1a1021c 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs @@ -1,38 +1,19 @@ using System; using System.Collections.Concurrent; using System.Threading.Tasks; -using MQTTnet.Diagnostics; -using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Packets; namespace MQTTnet.Client { public class MqttPacketDispatcher { - private readonly ConcurrentDictionary>> _awaiters = new ConcurrentDictionary>>(); - private readonly IMqttNetLogger _logger; - - public MqttPacketDispatcher(IMqttNetLogger logger) - { - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - } - - public async Task WaitForPacketAsync(Type responseType, ushort? identifier, TimeSpan timeout) + private readonly ConcurrentDictionary, TaskCompletionSource> _awaiters = new ConcurrentDictionary, TaskCompletionSource>(); + + public void Dispatch(Exception exception) { - var packetAwaiter = AddPacketAwaiter(responseType, identifier); - try - { - return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false); - } - catch (MqttCommunicationTimedOutException) - { - _logger.Warning("Timeout while waiting for packet of type '{0}'.", responseType.Name); - throw; - } - finally + foreach (var awaiter in _awaiters) { - RemovePacketAwaiter(responseType, identifier); + awaiter.Value.SetException(exception); } } @@ -40,21 +21,19 @@ namespace MQTTnet.Client { if (packet == null) throw new ArgumentNullException(nameof(packet)); - var type = packet.GetType(); - - if (_awaiters.TryGetValue(type, out var byId)) + ushort identifier = 0; + if (packet is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier.HasValue) { - ushort? identifier = 0; - if (packet is IMqttPacketWithIdentifier packetWithIdentifier) - { - identifier = packetWithIdentifier.PacketIdentifier; - } + identifier = packetWithIdentifier.PacketIdentifier.Value; + } - if (byId.TryRemove(identifier.Value, out var tcs)) - { - tcs.TrySetResult(packet); - return; - } + var type = packet.GetType(); + var key = new Tuple(identifier, type); + + if (_awaiters.TryRemove(key, out var tcs)) + { + tcs.TrySetResult(packet); + return; } throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); @@ -65,7 +44,7 @@ namespace MQTTnet.Client _awaiters.Clear(); } - private TaskCompletionSource AddPacketAwaiter(Type responseType, ushort? identifier) + public TaskCompletionSource AddPacketAwaiter(ushort? identifier) where TResponsePacket : MqttBasePacket { var tcs = new TaskCompletionSource(); @@ -73,25 +52,25 @@ namespace MQTTnet.Client { identifier = 0; } - - var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary>()); - if (!byId.TryAdd(identifier.Value, tcs)) + + var key = new Tuple(identifier ?? 0, typeof(TResponsePacket)); + if (!_awaiters.TryAdd(key, tcs)) { - throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); + throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{key.Item2.Name}' with identifier {key.Item1}."); } return tcs; } - private void RemovePacketAwaiter(Type responseType, ushort? identifier) + public void RemovePacketAwaiter(ushort? identifier) where TResponsePacket : MqttBasePacket { if (!identifier.HasValue) { identifier = 0; } - var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary>()); - byId.TryRemove(identifier.Value, out var _); + var key = new Tuple(identifier ?? 0, typeof(TResponsePacket)); + _awaiters.TryRemove(key, out var _); } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttReceivedApplicationMessageProcessingMode.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttReceivedApplicationMessageProcessingMode.cs new file mode 100644 index 0000000..651abe5 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttReceivedApplicationMessageProcessingMode.cs @@ -0,0 +1,8 @@ +namespace MQTTnet.Client +{ + public enum MqttReceivedApplicationMessageProcessingMode + { + SingleThread, + DedicatedThread + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs b/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs index 7d0adcd..86f58b3 100644 --- a/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs +++ b/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs @@ -1,6 +1,11 @@ -namespace MQTTnet.Exceptions +using System; + +namespace MQTTnet.Exceptions { public sealed class MqttCommunicationTimedOutException : MqttCommunicationException { + public MqttCommunicationTimedOutException() { } + public MqttCommunicationTimedOutException(Exception innerException) : base(innerException) { } + } } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs index a9cd18e..19d5824 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.InteropServices.WindowsRuntime; +using System.Threading; using System.Threading.Tasks; using Windows.Networking; using Windows.Networking.Sockets; @@ -17,17 +18,20 @@ namespace MQTTnet.Implementations { // ReSharper disable once MemberCanBePrivate.Global // ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global - public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. + public static int BufferSize { get; set; } = 4096; // Can be changed for fine tuning by library user. private readonly int _bufferSize = BufferSize; private readonly MqttClientTcpOptions _options; private StreamSocket _socket; + private Stream _readStream; + private Stream _writeStream; public MqttTcpChannel(MqttClientTcpOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); - _bufferSize = _options.BufferSize; + + _bufferSize = options.BufferSize; } public MqttTcpChannel(StreamSocket socket) @@ -37,16 +41,15 @@ namespace MQTTnet.Implementations CreateStreams(); } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } - public static Func> CustomIgnorableServerCertificateErrorsResolver { get; set; } - public async Task ConnectAsync() + public async Task ConnectAsync(CancellationToken cancellationToken) { if (_socket == null) { _socket = new StreamSocket(); + _socket.Control.NoDelay = true; + _socket.Control.KeepAlive = true; } if (!_options.TlsOptions.UseTls) @@ -74,11 +77,22 @@ namespace MQTTnet.Implementations return Task.FromResult(0); } + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _readStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _writeStream.WriteAsync(buffer, offset, count, cancellationToken); + await _writeStream.FlushAsync(cancellationToken); + } + public void Dispose() { try { - SendStream?.Dispose(); + _readStream?.Dispose(); } catch (ObjectDisposedException) { @@ -88,12 +102,12 @@ namespace MQTTnet.Implementations } finally { - SendStream = null; + _readStream = null; } try { - ReceiveStream?.Dispose(); + _writeStream?.Dispose(); } catch (ObjectDisposedException) { @@ -103,7 +117,7 @@ namespace MQTTnet.Implementations } finally { - ReceiveStream = null; + _writeStream = null; } try @@ -122,12 +136,6 @@ namespace MQTTnet.Implementations } } - private void CreateStreams() - { - SendStream = _socket.OutputStream.AsStreamForWrite(_bufferSize); - ReceiveStream = _socket.InputStream.AsStreamForRead(_bufferSize); - } - private static Certificate LoadCertificate(MqttClientTcpOptions options) { if (options.TlsOptions.Certificates == null || !options.TlsOptions.Certificates.Any()) @@ -171,6 +179,12 @@ namespace MQTTnet.Implementations return result; } + + private void CreateStreams() + { + _readStream = _socket.InputStream.AsStreamForRead(_bufferSize); + _writeStream = _socket.OutputStream.AsStreamForWrite(_bufferSize); + } } } #endif \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index 8d43d87..fe01781 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -7,6 +7,7 @@ using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using System.IO; using System.Linq; +using System.Threading; using MQTTnet.Channel; using MQTTnet.Client; @@ -14,20 +15,10 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttChannel { -#if NET452 || NET461 - // ReSharper disable once MemberCanBePrivate.Global - // ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global - public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. - - private readonly int _bufferSize = BufferSize; -#else - private readonly int _bufferSize = 0; -#endif - private readonly MqttClientTcpOptions _options; private Socket _socket; - private SslStream _sslStream; + private Stream _stream; /// /// called on client sockets are created in connect @@ -35,7 +26,6 @@ namespace MQTTnet.Implementations public MqttTcpChannel(MqttClientTcpOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); - _bufferSize = options.BufferSize; } /// @@ -45,21 +35,17 @@ namespace MQTTnet.Implementations public MqttTcpChannel(Socket socket, SslStream sslStream) { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); - _sslStream = sslStream; - CreateStreams(); + CreateStream(sslStream); } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } - public static Func CustomCertificateValidationCallback { get; set; } - public async Task ConnectAsync() + public async Task ConnectAsync(CancellationToken cancellationToken) { if (_socket == null) { - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; } #if NET452 || NET461 @@ -68,13 +54,14 @@ namespace MQTTnet.Implementations await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); #endif + SslStream sslStream = null; if (_options.TlsOptions.UseTls) { - _sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); - await _sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); + sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); + await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); } - - CreateStreams(); + + CreateStream(sslStream); } public Task DisconnectAsync() @@ -83,46 +70,21 @@ namespace MQTTnet.Implementations return Task.FromResult(0); } - public void Dispose() + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var oneStreamIsUsed = SendStream != null && ReceiveStream != null && ReferenceEquals(SendStream, ReceiveStream); - - try - { - SendStream?.Dispose(); - } - catch (ObjectDisposedException) - { - } - catch (NullReferenceException) - { - } - finally - { - SendStream = null; - } + return _stream.ReadAsync(buffer, offset, count, cancellationToken); + } - try - { - if (!oneStreamIsUsed) - { - ReceiveStream?.Dispose(); - } - } - catch (ObjectDisposedException) - { - } - catch (NullReferenceException) - { - } - finally - { - ReceiveStream = null; - } + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _stream.WriteAsync(buffer, offset, count, cancellationToken); + } + public void Dispose() + { try { - _sslStream?.Dispose(); + _stream?.Dispose(); } catch (ObjectDisposedException) { @@ -132,7 +94,7 @@ namespace MQTTnet.Implementations } finally { - _sslStream = null; + _stream = null; } try @@ -198,25 +160,16 @@ namespace MQTTnet.Implementations return certificates; } - private void CreateStreams() + private void CreateStream(Stream stream) { - Stream stream; - if (_sslStream != null) + if (stream != null) { - stream = _sslStream; + _stream = stream; } else { - stream = new NetworkStream(_socket, true); + _stream = new NetworkStream(_socket, true); } - -#if NET452 || NET461 - SendStream = new BufferedStream(stream, _bufferSize); - ReceiveStream = new BufferedStream(stream, _bufferSize); -#else - SendStream = stream; - ReceiveStream = stream; -#endif } } } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs index e71d6ff..b2c3110 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs @@ -9,7 +9,7 @@ using MQTTnet.Server; namespace MQTTnet.Implementations { - public class MqttTcpServerAdapter : IMqttServerAdapter, IDisposable + public class MqttTcpServerAdapter : IMqttServerAdapter { private readonly IMqttNetLogger _logger; private StreamSocketListener _defaultEndpointSocket; @@ -26,12 +26,19 @@ namespace MQTTnet.Implementations if (options == null) throw new ArgumentNullException(nameof(options)); if (_defaultEndpointSocket != null) throw new InvalidOperationException("Server is already started."); - + if (options.DefaultEndpointOptions.IsEnabled) { _defaultEndpointSocket = new StreamSocketListener(); - await _defaultEndpointSocket.BindServiceNameAsync(options.GetDefaultEndpointPort().ToString(), SocketProtectionLevel.PlainSocket); + + // This also affects the client sockets. + _defaultEndpointSocket.Control.NoDelay = true; + _defaultEndpointSocket.Control.KeepAlive = true; + _defaultEndpointSocket.Control.QualityOfService = SocketQualityOfService.LowLatency; _defaultEndpointSocket.ConnectionReceived += AcceptDefaultEndpointConnectionsAsync; + + await _defaultEndpointSocket.BindServiceNameAsync(options.GetDefaultEndpointPort().ToString(), SocketProtectionLevel.PlainSocket); + } if (options.TlsEndpointOptions.IsEnabled) @@ -55,7 +62,7 @@ namespace MQTTnet.Implementations public void Dispose() { - StopAsync(); + StopAsync().GetAwaiter().GetResult(); } private void AcceptDefaultEndpointConnectionsAsync(StreamSocketListener sender, StreamSocketListenerConnectionReceivedEventArgs args) diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs index dc33a3f..0404de3 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs @@ -14,7 +14,7 @@ using MQTTnet.Server; namespace MQTTnet.Implementations { - public class MqttTcpServerAdapter : IMqttServerAdapter, IDisposable + public class MqttTcpServerAdapter : IMqttServerAdapter { private readonly IMqttNetLogger _logger; @@ -38,11 +38,16 @@ namespace MQTTnet.Implementations if (options.DefaultEndpointOptions.IsEnabled) { - _defaultEndpointSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); + _defaultEndpointSocket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; + _defaultEndpointSocket.Bind(new IPEndPoint(options.DefaultEndpointOptions.BoundIPAddress, options.GetDefaultEndpointPort())); _defaultEndpointSocket.Listen(options.ConnectionBacklog); - Task.Run(async () => await AcceptDefaultEndpointConnectionsAsync(_cancellationTokenSource.Token).ConfigureAwait(false), _cancellationTokenSource.Token).ConfigureAwait(false); + Task.Factory.StartNew( + () => AcceptDefaultEndpointConnectionsAsync(_cancellationTokenSource.Token), + _cancellationTokenSource.Token, + TaskCreationOptions.LongRunning, + TaskScheduler.Current); } if (options.TlsEndpointOptions.IsEnabled) @@ -62,7 +67,11 @@ namespace MQTTnet.Implementations _tlsEndpointSocket.Bind(new IPEndPoint(options.TlsEndpointOptions.BoundIPAddress, options.GetTlsEndpointPort())); _tlsEndpointSocket.Listen(options.ConnectionBacklog); - Task.Run(async () => await AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token).ConfigureAwait(false), _cancellationTokenSource.Token).ConfigureAwait(false); + Task.Factory.StartNew( + () => AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token), + _cancellationTokenSource.Token, + TaskCreationOptions.LongRunning, + TaskScheduler.Current); } return Task.FromResult(0); @@ -78,7 +87,7 @@ namespace MQTTnet.Implementations _defaultEndpointSocket = null; _tlsCertificate = null; - + _tlsEndpointSocket?.Dispose(); _tlsEndpointSocket = null; @@ -87,7 +96,7 @@ namespace MQTTnet.Implementations public void Dispose() { - StopAsync(); + StopAsync().GetAwaiter().GetResult(); } private async Task AcceptDefaultEndpointConnectionsAsync(CancellationToken cancellationToken) @@ -102,6 +111,7 @@ namespace MQTTnet.Implementations #else var clientSocket = await _defaultEndpointSocket.AcceptAsync().ConfigureAwait(false); #endif + clientSocket.NoDelay = true; var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); @@ -137,7 +147,7 @@ namespace MQTTnet.Implementations var sslStream = new SslStream(new NetworkStream(clientSocket)); await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false); - + var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, sslStream), new MqttPacketSerializer(), _logger); ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs index 057a07e..5ceecbb 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs @@ -1,5 +1,4 @@ using System; -using System.IO; using System.Net.WebSockets; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -11,22 +10,22 @@ namespace MQTTnet.Implementations { public sealed class MqttWebSocketChannel : IMqttChannel { - // ReSharper disable once MemberCanBePrivate.Global - // ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global - public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. - + private readonly SemaphoreSlim _sendLock = new SemaphoreSlim(1, 1); private readonly MqttClientWebSocketOptions _options; - private ClientWebSocket _webSocket; + + private WebSocket _webSocket; public MqttWebSocketChannel(MqttClientWebSocketOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } + public MqttWebSocketChannel(WebSocket webSocket) + { + _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); + } - public async Task ConnectAsync() + public async Task ConnectAsync(CancellationToken cancellationToken) { var uri = _options.Uri; if (!uri.StartsWith("ws://", StringComparison.OrdinalIgnoreCase) && !uri.StartsWith("wss://", StringComparison.OrdinalIgnoreCase)) @@ -41,42 +40,40 @@ namespace MQTTnet.Implementations } } - _webSocket = new ClientWebSocket(); - + var clientWebSocket = new ClientWebSocket(); + if (_options.RequestHeaders != null) { foreach (var requestHeader in _options.RequestHeaders) { - _webSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); - } + clientWebSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); + } } if (_options.SubProtocols != null) { foreach (var subProtocol in _options.SubProtocols) { - _webSocket.Options.AddSubProtocol(subProtocol); + clientWebSocket.Options.AddSubProtocol(subProtocol); } } if (_options.CookieContainer != null) { - _webSocket.Options.Cookies = _options.CookieContainer; + clientWebSocket.Options.Cookies = _options.CookieContainer; } if (_options.TlsOptions?.UseTls == true && _options.TlsOptions?.Certificates != null) { - _webSocket.Options.ClientCertificates = new X509CertificateCollection(); + clientWebSocket.Options.ClientCertificates = new X509CertificateCollection(); foreach (var certificate in _options.TlsOptions.Certificates) { - _webSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); + clientWebSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); } } - await _webSocket.ConnectAsync(new Uri(uri), CancellationToken.None).ConfigureAwait(false); - - SendStream = new WebSocketStream(_webSocket); - ReceiveStream = SendStream; + await clientWebSocket.ConnectAsync(new Uri(uri), cancellationToken).ConfigureAwait(false); + _webSocket = clientWebSocket; } public async Task DisconnectAsync() @@ -94,8 +91,32 @@ namespace MQTTnet.Implementations Dispose(); } + public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, offset, count), cancellationToken).ConfigureAwait(false); + return response.Count; + } + + public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + // This lock is required because the client will throw an exception if _SendAsync_ is + // called from multiple threads at the same time. But this issue only happens with several + // framework versions. + await _sendLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken).ConfigureAwait(false); + } + finally + { + _sendLock.Release(); + } + } + public void Dispose() { + _sendLock?.Dispose(); + try { _webSocket?.Dispose(); @@ -106,7 +127,7 @@ namespace MQTTnet.Implementations finally { _webSocket = null; - } + } } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs deleted file mode 100644 index bad0393..0000000 --- a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs +++ /dev/null @@ -1,137 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net.WebSockets; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Exceptions; - -namespace MQTTnet.Implementations -{ - public class WebSocketStream : Stream - { - private readonly byte[] _chunckBuffer = new byte[MqttWebSocketChannel.BufferSize]; - private readonly Queue _buffer = new Queue(MqttWebSocketChannel.BufferSize); - private readonly WebSocket _webSocket; - - public WebSocketStream(WebSocket webSocket) - { - _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); - } - - public override bool CanRead => true; - - public override bool CanSeek => false; - - public override bool CanWrite => true; - - public override long Length => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - public override void Flush() - { - } - - public override Task FlushAsync(CancellationToken cancellationToken) - { - return Task.FromResult(0); - } - - public override int Read(byte[] buffer, int offset, int count) - { - return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - var bytesRead = 0; - - // Use existing date from buffer. - while (count > 0 && _buffer.Any()) - { - buffer[offset] = _buffer.Dequeue(); - count--; - bytesRead++; - offset++; - } - - if (count == 0) - { - return bytesRead; - } - - // Fetch new data if the buffer is not full. - while (_webSocket.State == WebSocketState.Open) - { - await FetchChunkAsync(cancellationToken).ConfigureAwait(false); - - while (count > 0 && _buffer.Any()) - { - buffer[offset] = _buffer.Dequeue(); - count--; - bytesRead++; - offset++; - } - - if (count == 0) - { - return bytesRead; - } - } - - if (_webSocket.State == WebSocketState.Closed) - { - throw new MqttCommunicationException("WebSocket connection closed."); - } - - return bytesRead; - } - - public override void Write(byte[] buffer, int offset, int count) - { - WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } - - public override void SetLength(long value) - { - throw new NotSupportedException(); - } - - private async Task FetchChunkAsync(CancellationToken cancellationToken) - { - var response = await _webSocket.ReceiveAsync(new ArraySegment(_chunckBuffer, 0, _chunckBuffer.Length), cancellationToken).ConfigureAwait(false); - - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); - } - else if (response.MessageType == WebSocketMessageType.Binary) - { - for (var i = 0; i < response.Count; i++) - { - _buffer.Enqueue(_chunckBuffer[i]); - } - } - else if (response.MessageType == WebSocketMessageType.Text) - { - throw new MqttProtocolViolationException("WebSocket channel received TEXT message."); - } - } - } -} diff --git a/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs b/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs new file mode 100644 index 0000000..ad20646 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public sealed class AsyncAutoResetEvent : IDisposable + { + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(0, 1); + + public Task WaitOneAsync(CancellationToken cancellationToken) + { + return _semaphore.WaitAsync(cancellationToken); + } + + public void Set() + { + _semaphore.Release(); + } + + public void Dispose() + { + _semaphore?.Dispose(); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs b/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs new file mode 100644 index 0000000..145e385 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public sealed class AsyncLock : IDisposable + { + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + + public Task EnterAsync(CancellationToken cancellationToken) + { + return _semaphore.WaitAsync(cancellationToken); + } + + public void Exit() + { + _semaphore.Release(); + } + + public void Dispose() + { + _semaphore?.Dispose(); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs b/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs index b1d1a38..79eb2c4 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Internal { - internal static class MqttApplicationMessageExtensions + public static class MqttApplicationMessageExtensions { public static MqttApplicationMessage ToApplicationMessage(this MqttPublishPacket publishPacket) { diff --git a/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs b/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs index 876c549..288ac0b 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs @@ -7,72 +7,52 @@ namespace MQTTnet.Internal { public static class TaskExtensions { - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter(Func action, TimeSpan timeout, CancellationToken cancellationToken) { - if (task == null) throw new ArgumentNullException(nameof(task)); + if (action == null) throw new ArgumentNullException(nameof(action)); - using (var timeoutCts = new CancellationTokenSource()) + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) { try { - var timeoutTask = Task.Delay(timeout, timeoutCts.Token); - var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - - if (finishedTask == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsCanceled) + await action(linkedCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException exception) + { + var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; + if (timeoutReached) { - throw new TaskCanceledException(); + throw new MqttCommunicationTimedOutException(exception); } - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception?.GetBaseException()); - } - } - finally - { - timeoutCts.Cancel(); + throw; } } } - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter(Func> action, TimeSpan timeout, CancellationToken cancellationToken) { - if (task == null) throw new ArgumentNullException(nameof(task)); + if (action == null) throw new ArgumentNullException(nameof(action)); - using (var timeoutCts = new CancellationTokenSource()) + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) { try { - var timeoutTask = Task.Delay(timeout, timeoutCts.Token); - var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - - if (finishedTask == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsCanceled) - { - throw new TaskCanceledException(); - } - - if (task.IsFaulted) + return await action(linkedCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException exception) + { + var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; + if (timeoutReached) { - throw new MqttCommunicationException(task.Exception.GetBaseException()); + throw new MqttCommunicationTimedOutException(exception); } - return task.Result; - } - finally - { - timeoutCts.Cancel(); + throw; } } - } + } } } diff --git a/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs b/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs new file mode 100644 index 0000000..b380b08 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs @@ -0,0 +1,41 @@ +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Channel; + +namespace MQTTnet.Core.Internal +{ + public class TestMqttChannel : IMqttChannel + { + private readonly MemoryStream _stream; + + public TestMqttChannel(MemoryStream stream) + { + _stream = stream; + } + + public void Dispose() + { + } + + public Task ConnectAsync(CancellationToken cancellationToken) + { + return Task.FromResult(0); + } + + public Task DisconnectAsync() + { + return Task.FromResult(0); + } + + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _stream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _stream.WriteAsync(buffer, offset, count, cancellationToken); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/MQTTnet.Netstandard.csproj b/Frameworks/MQTTnet.NetStandard/MQTTnet.Netstandard.csproj index cb154c2..fc1e41a 100644 --- a/Frameworks/MQTTnet.NetStandard/MQTTnet.Netstandard.csproj +++ b/Frameworks/MQTTnet.NetStandard/MQTTnet.Netstandard.csproj @@ -6,7 +6,6 @@ MQTTnet MQTTnet False - Full 0.0.0.0 0.0.0.0 0.0.0.0 @@ -19,7 +18,7 @@ false - + false UAP,Version=v10.0 UAP @@ -32,6 +31,10 @@ $(MSBuildExtensionsPath)\Microsoft\WindowsXaml\v$(VisualStudioVersion)\Microsoft.Windows.UI.Xaml.CSharp.targets + + Full + + @@ -43,17 +46,17 @@ - + - + - + diff --git a/Frameworks/MQTTnet.NetStandard/ManagedClient/ManagedMqttClient.cs b/Frameworks/MQTTnet.NetStandard/ManagedClient/ManagedMqttClient.cs index f1ce45d..bbfc1c7 100644 --- a/Frameworks/MQTTnet.NetStandard/ManagedClient/ManagedMqttClient.cs +++ b/Frameworks/MQTTnet.NetStandard/ManagedClient/ManagedMqttClient.cs @@ -71,7 +71,7 @@ namespace MQTTnet.ManagedClient _connectionCancellationToken = new CancellationTokenSource(); #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(async () => await MaintainConnectionAsync(_connectionCancellationToken.Token).ConfigureAwait(false), _connectionCancellationToken.Token).ConfigureAwait(false); + Task.Run(() => MaintainConnectionAsync(_connectionCancellationToken.Token), _connectionCancellationToken.Token); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed _logger.Info("Started"); @@ -190,10 +190,7 @@ namespace MQTTnet.ManagedClient if (connectionState == ReconnectionResult.Reconnected || _subscriptionsNotPushed) { await SynchronizeSubscriptionsAsync().ConfigureAwait(false); - StartPublishing(); - - return; } @@ -375,7 +372,7 @@ namespace MQTTnet.ManagedClient _publishingCancellationToken = cts; #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(async () => await PublishQueuedMessagesAsync(cts.Token).ConfigureAwait(false), cts.Token).ConfigureAwait(false); + Task.Run(() => PublishQueuedMessagesAsync(cts.Token), cts.Token); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } diff --git a/Frameworks/MQTTnet.NetStandard/MqttApplicationMessageBuilder.cs b/Frameworks/MQTTnet.NetStandard/MqttApplicationMessageBuilder.cs index ce949b8..79c65ec 100644 --- a/Frameworks/MQTTnet.NetStandard/MqttApplicationMessageBuilder.cs +++ b/Frameworks/MQTTnet.NetStandard/MqttApplicationMessageBuilder.cs @@ -32,7 +32,12 @@ namespace MQTTnet return this; } - public MqttApplicationMessageBuilder WithPayload(MemoryStream payload) + public MqttApplicationMessageBuilder WithPayload(Stream payload) + { + return WithPayload(payload, payload.Length - payload.Position); + } + + public MqttApplicationMessageBuilder WithPayload(Stream payload, long length) { if (payload == null) { @@ -46,7 +51,7 @@ namespace MQTTnet } else { - _payload = new byte[payload.Length - payload.Position]; + _payload = new byte[length]; payload.Read(_payload, 0, _payload.Length); } diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs index 6c24a47..a4d8453 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Packets public override string ToString() { var topicFiltersText = string.Join(",", TopicFilters); - return "Subscribe: [PacketIdentifier=" + PacketIdentifier + "] [TopicFilters=" + topicFiltersText + "]"; + return "Unsubscribe: [PacketIdentifier=" + PacketIdentifier + "] [TopicFilters=" + topicFiltersText + "]"; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs index 6577b0a..3688330 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.IO; using MQTTnet.Packets; @@ -9,8 +8,8 @@ namespace MQTTnet.Serializer { MqttProtocolVersion ProtocolVersion { get; set; } - ICollection> Serialize(MqttBasePacket mqttPacket); + ArraySegment Serialize(MqttBasePacket mqttPacket); - MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); + MqttBasePacket Deserialize(MqttPacketHeader header, Stream body); } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs index bd43f7d..4c43579 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs @@ -1,9 +1,9 @@ using System; -using System.Collections.Generic; using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; +using MQTTnet.Channel; using MQTTnet.Exceptions; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -13,7 +13,7 @@ namespace MQTTnet.Serializer public sealed class MqttPacketReader : BinaryReader { private readonly MqttPacketHeader _header; - + public MqttPacketReader(MqttPacketHeader header, Stream bodyStream) : base(bodyStream, Encoding.UTF8, true) { @@ -22,7 +22,7 @@ namespace MQTTnet.Serializer public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; - public static async Task ReadHeaderAsync(Stream stream, CancellationToken cancellationToken) + public static async Task ReadHeaderAsync(IMqttChannel stream, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { @@ -33,7 +33,7 @@ namespace MQTTnet.Serializer // some large delay and thus the thread should be put back to the pool (await). So ReadByte() // is not an option here. var buffer = new byte[1]; - var readCount = await stream.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); + var readCount = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); if (readCount <= 0) { return null; @@ -89,15 +89,14 @@ namespace MQTTnet.Serializer return ReadBytes(_header.BodyLength - (int)BaseStream.Position); } - private static async Task ReadBodyLengthAsync(Stream stream, CancellationToken cancellationToken) + private static async Task ReadBodyLengthAsync(IMqttChannel stream, CancellationToken cancellationToken) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 1; var value = 0; - byte encodedByte; - + int encodedByte; var buffer = new byte[1]; - var readBytes = new List(); + do { if (cancellationToken.IsCancellationRequested) @@ -112,12 +111,11 @@ namespace MQTTnet.Serializer } encodedByte = buffer[0]; - readBytes.Add(encodedByte); value += (byte)(encodedByte & 127) * multiplier; if (multiplier > 128 * 128 * 128) { - throw new MqttProtocolViolationException($"Remaining length is invalid (Data={string.Join(",", readBytes)})."); + throw new MqttProtocolViolationException("Remaining length is invalid."); } multiplier *= 128; diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 0904c75..18818ce 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -2,7 +2,6 @@ using MQTTnet.Packets; using MQTTnet.Protocol; using System; -using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; @@ -12,37 +11,42 @@ namespace MQTTnet.Serializer public sealed class MqttPacketSerializer : IMqttPacketSerializer { private static byte[] ProtocolVersionV311Name { get; } = Encoding.UTF8.GetBytes("MQTT"); - private static byte[] ProtocolVersionV310Name { get; } = Encoding.UTF8.GetBytes("MQIs"); + private static byte[] ProtocolVersionV310Name { get; } = Encoding.UTF8.GetBytes("MQIsdp"); public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; - public ICollection> Serialize(MqttBasePacket packet) + public ArraySegment Serialize(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); using (var stream = new MemoryStream(128)) using (var writer = new MqttPacketWriter(stream)) { + // Leave enough head space for max header size (fixed + 4 variable remaining length) + stream.Position = 5; var fixedHeader = SerializePacket(packet, writer); - var remainingLength = (int)stream.Length; + + stream.Position = 1; + var remainingLength = MqttPacketWriter.EncodeRemainingLength((int)stream.Length - 5, stream); + + var headerSize = remainingLength + 1; + var headerOffset = 5 - headerSize; + + // Position cursor on correct offset on beginining of array (has leading 0x0) + stream.Position = headerOffset; + writer.Write(fixedHeader); - MqttPacketWriter.WriteRemainingLength(remainingLength, writer); - var headerLength = (int)stream.Length - remainingLength; #if NET461 || NET452 || NETSTANDARD2_0 var buffer = stream.GetBuffer(); #else var buffer = stream.ToArray(); #endif - return new List> - { - new ArraySegment(buffer, remainingLength, headerLength), - new ArraySegment(buffer, 0, remainingLength) - }; + return new ArraySegment(buffer, headerOffset, (int)stream.Length - headerOffset); } } - public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) + public MqttBasePacket Deserialize(MqttPacketHeader header, Stream body) { if (header == null) throw new ArgumentNullException(nameof(header)); if (body == null) throw new ArgumentNullException(nameof(body)); @@ -178,7 +182,7 @@ namespace MQTTnet.Serializer var topic = reader.ReadStringWithLengthPrefix(); - ushort packetIdentifier = 0; + ushort? packetIdentifier = null; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { packetIdentifier = reader.ReadUInt16(); @@ -186,12 +190,12 @@ namespace MQTTnet.Serializer var packet = new MqttPublishPacket { + PacketIdentifier = packetIdentifier, Retain = retain, - QualityOfServiceLevel = qualityOfServiceLevel, - Dup = dup, Topic = topic, Payload = reader.ReadRemainingData(), - PacketIdentifier = packetIdentifier + QualityOfServiceLevel = qualityOfServiceLevel, + Dup = dup }; return packet; @@ -199,22 +203,30 @@ namespace MQTTnet.Serializer private static MqttBasePacket DeserializeConnect(MqttPacketReader reader) { - reader.ReadBytes(2); // Skip 2 bytes + reader.ReadBytes(2); // Skip 2 bytes for header and remaining length. MqttProtocolVersion protocolVersion; var protocolName = reader.ReadBytes(4); - if (protocolName.SequenceEqual(ProtocolVersionV310Name)) - { - reader.ReadBytes(2); - protocolVersion = MqttProtocolVersion.V310; - } - else if (protocolName.SequenceEqual(ProtocolVersionV311Name)) + + if (protocolName.SequenceEqual(ProtocolVersionV311Name)) { protocolVersion = MqttProtocolVersion.V311; } else { - throw new MqttProtocolViolationException("Protocol name is not supported."); + var buffer = new byte[6]; + Array.Copy(protocolName, buffer, 4); + protocolName = reader.ReadBytes(2); + Array.Copy(protocolName, 0, buffer, 4, 2); + + if (protocolName.SequenceEqual(ProtocolVersionV310Name)) + { + protocolVersion = MqttProtocolVersion.V310; + } + else + { + throw new MqttProtocolViolationException("Protocol name is not supported."); + } } reader.ReadByte(); // Skip protocol level @@ -293,7 +305,7 @@ namespace MQTTnet.Serializer return packet; } - + private static void ValidateConnectPacket(MqttConnectPacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); @@ -319,16 +331,15 @@ namespace MQTTnet.Serializer ValidateConnectPacket(packet); // Write variable header - writer.Write(0x00, 0x04); // 3.1.2.1 Protocol Name if (ProtocolVersion == MqttProtocolVersion.V311) { - writer.Write(ProtocolVersionV311Name); - writer.Write(0x04); // 3.1.2.2 Protocol Level (4) + writer.WriteWithLengthPrefix(ProtocolVersionV311Name); + writer.Write(0x04); // 3.1.2.2 Protocol Level 4 } else { - writer.Write(ProtocolVersionV310Name); - writer.Write(0x64, 0x70, 0x03); // Protocol Level (0x03) + writer.WriteWithLengthPrefix(ProtocolVersionV310Name); + writer.Write(0x03); // Protocol Level 3 } var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags @@ -347,6 +358,11 @@ namespace MQTTnet.Serializer connectFlags.Write(false); } + if (packet.Password != null && packet.Username == null) + { + throw new MqttProtocolViolationException("If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22]."); + } + connectFlags.Write(packet.Password != null); connectFlags.Write(packet.Username != null); diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs index cb3d458..54f40eb 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs @@ -55,14 +55,20 @@ namespace MQTTnet.Serializer Write(value); } - public static void WriteRemainingLength(int length, BinaryWriter target) + public static int EncodeRemainingLength(int length, MemoryStream stream) { - if (length == 0) + // write the encoded remaining length right aligned on the 4 byte buffer + + if (length <= 0) { - target.Write((byte)0); - return; + stream.Seek(3, SeekOrigin.Current); + stream.WriteByte(0); + return 1; } + var buffer = new byte[4]; + var offset = 0; + // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. var x = length; do @@ -74,8 +80,15 @@ namespace MQTTnet.Serializer encodedByte = encodedByte | 128; } - target.Write((byte)encodedByte); + buffer[offset] = (byte)encodedByte; + + offset++; } while (x > 0); + + stream.Seek(4 - offset, SeekOrigin.Current); + stream.Write(buffer, 0, offset); + + return offset; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/ConnectedMqttClient.cs b/Frameworks/MQTTnet.NetStandard/Server/ConnectedMqttClient.cs index 4fa8e4e..c12759a 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/ConnectedMqttClient.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/ConnectedMqttClient.cs @@ -3,11 +3,15 @@ using MQTTnet.Serializer; namespace MQTTnet.Server { + // TODO: Rename to "RegisteredClient" + // TODO: Add IsConnected + // TODO: Add interface + public class ConnectedMqttClient { public string ClientId { get; set; } - public MqttProtocolVersion ProtocolVersion { get; set; } + public MqttProtocolVersion? ProtocolVersion { get; set; } public TimeSpan LastPacketReceived { get; set; } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientDisconnectType.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientDisconnectType.cs new file mode 100644 index 0000000..19e9da8 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientDisconnectType.cs @@ -0,0 +1,8 @@ +namespace MQTTnet.Server +{ + public enum MqttClientDisconnectType + { + Clean, + NotClean + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientKeepAliveMonitor.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientKeepAliveMonitor.cs index 3f21536..8dc96e4 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientKeepAliveMonitor.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientKeepAliveMonitor.cs @@ -13,12 +13,12 @@ namespace MQTTnet.Server private readonly Stopwatch _lastNonKeepAlivePacketReceivedTracker = new Stopwatch(); private readonly string _clientId; - private readonly Func _timeoutCallback; + private readonly Action _timeoutCallback; private readonly IMqttNetLogger _logger; private Task _workerTask; - public MqttClientKeepAliveMonitor(string clientId, Func timeoutCallback, IMqttNetLogger logger) + public MqttClientKeepAliveMonitor(string clientId, Action timeoutCallback, IMqttNetLogger logger) { _clientId = clientId; _timeoutCallback = timeoutCallback; @@ -61,10 +61,7 @@ namespace MQTTnet.Server { _logger.Warning("Client '{0}': Did not receive any packet or keep alive signal.", _clientId); - if (_timeoutCallback != null) - { - await _timeoutCallback().ConfigureAwait(false); - } + _timeoutCallback?.Invoke(); return; } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs index f646a19..fd89691 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs @@ -1,11 +1,11 @@ using System; using System.Collections.Concurrent; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -14,7 +14,7 @@ namespace MQTTnet.Server public sealed class MqttClientPendingMessagesQueue : IDisposable { private readonly ConcurrentQueue _queue = new ConcurrentQueue(); - private readonly SemaphoreSlim _queueWaitSemaphore = new SemaphoreSlim(0); + private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); private readonly IMqttServerOptions _options; private readonly MqttClientSession _clientSession; private readonly IMqttNetLogger _logger; @@ -66,7 +66,7 @@ namespace MQTTnet.Server if (packet == null) throw new ArgumentNullException(nameof(packet)); _queue.Enqueue(packet); - _queueWaitSemaphore.Release(); + _queueAutoResetEvent.Set(); _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); } @@ -94,7 +94,7 @@ namespace MQTTnet.Server MqttBasePacket packet = null; try { - await _queueWaitSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false); if (!_queue.TryDequeue(out packet)) { throw new InvalidOperationException(); // should not happen @@ -105,7 +105,7 @@ namespace MQTTnet.Server return; } - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { packet }).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { packet }, cancellationToken).ConfigureAwait(false); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); } @@ -132,21 +132,21 @@ namespace MQTTnet.Server if (publishPacket.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { publishPacket.Dup = true; - _queue.Enqueue(packet); - _queueWaitSemaphore.Release(); + + Enqueue(publishPacket); } } if (!cancellationToken.IsCancellationRequested) { - await _clientSession.StopAsync().ConfigureAwait(false); + _clientSession.Stop(MqttClientDisconnectType.NotClean); } } } public void Dispose() { - _queueWaitSemaphore?.Dispose(); + _queueAutoResetEvent?.Dispose(); } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs index fbb11b9..7d50bad 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs @@ -16,11 +16,11 @@ namespace MQTTnet.Server public sealed class MqttClientSession : IDisposable { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); - private readonly IMqttServerOptions _options; - private readonly IMqttNetLogger _logger; private readonly MqttRetainedMessagesManager _retainedMessagesManager; + private readonly IMqttNetLogger _logger; + private readonly IMqttServerOptions _options; + private readonly MqttClientSessionsManager _sessionsManager; - private IMqttChannelAdapter _adapter; private CancellationTokenSource _cancellationTokenSource; private MqttApplicationMessage _willMessage; private bool _wasCleanDisconnect; @@ -28,22 +28,22 @@ namespace MQTTnet.Server public MqttClientSession( string clientId, IMqttServerOptions options, + MqttClientSessionsManager sessionsManager, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetLogger logger) { _options = options ?? throw new ArgumentNullException(nameof(options)); + _sessionsManager = sessionsManager; _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); ClientId = clientId; - KeepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, StopDueToKeepAliveTimeoutAsync, _logger); - SubscriptionsManager = new MqttClientSubscriptionsManager(_options, clientId); + KeepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, StopDueToKeepAliveTimeout, _logger); + SubscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server); PendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger); } - public Func ApplicationMessageReceivedCallback { get; set; } - public MqttClientSubscriptionsManager SubscriptionsManager { get; } public MqttClientPendingMessagesQueue PendingMessagesQueue { get; } @@ -52,9 +52,9 @@ namespace MQTTnet.Server public string ClientId { get; } - public MqttProtocolVersion? ProtocolVersion => _adapter?.PacketSerializer.ProtocolVersion; + public MqttProtocolVersion? ProtocolVersion { get; private set; } - public bool IsConnected => _adapter != null; + public bool IsConnected { get; private set; } public async Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) { @@ -63,34 +63,45 @@ namespace MQTTnet.Server try { - var cancellationTokenSource = new CancellationTokenSource(); + _cancellationTokenSource = new CancellationTokenSource(); _wasCleanDisconnect = false; _willMessage = connectPacket.WillMessage; - _adapter = adapter; - _cancellationTokenSource = cancellationTokenSource; - PendingMessagesQueue.Start(adapter, cancellationTokenSource.Token); - KeepAliveMonitor.Start(connectPacket.KeepAlivePeriod, cancellationTokenSource.Token); + IsConnected = true; + ProtocolVersion = adapter.PacketSerializer.ProtocolVersion; + + PendingMessagesQueue.Start(adapter, _cancellationTokenSource.Token); + KeepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token); - await ReceivePacketsAsync(adapter, cancellationTokenSource.Token).ConfigureAwait(false); + await ReceivePacketsAsync(adapter, _cancellationTokenSource.Token).ConfigureAwait(false); } catch (OperationCanceledException) { } catch (MqttCommunicationException exception) { - _logger.Warning(exception, "Client '{0}': Communication exception while processing client packets.", ClientId); + _logger.Warning(exception, + "Client '{0}': Communication exception while processing client packets.", ClientId); } catch (Exception exception) { - _logger.Error(exception, "Client '{0}': Unhandled exception while processing client packets.", ClientId); + _logger.Error(exception, + "Client '{0}': Unhandled exception while processing client packets.", ClientId); + } + finally + { + ProtocolVersion = null; + IsConnected = false; + + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; } return _wasCleanDisconnect; } - public async Task StopAsync(bool wasCleanDisconnect = false) + public void Stop(MqttClientDisconnectType type) { try { @@ -99,37 +110,31 @@ namespace MQTTnet.Server return; } - _wasCleanDisconnect = wasCleanDisconnect; + _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; _cancellationTokenSource?.Cancel(false); - PendingMessagesQueue.WaitForCompletion(); KeepAliveMonitor.WaitForCompletion(); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; - - _adapter = null; - - _logger.Info("Client '{0}': Session stopped.", ClientId); - } - finally - { var willMessage = _willMessage; _willMessage = null; // clear willmessage so it is send just once - if (willMessage != null && !wasCleanDisconnect) + if (willMessage != null && !_wasCleanDisconnect) { - await ApplicationMessageReceivedCallback(this, willMessage).ConfigureAwait(false); + _sessionsManager.StartDispatchApplicationMessage(this, willMessage); } } + finally + { + _logger.Info("Client '{0}': Session stopped.", ClientId); + } } public async Task EnqueueApplicationMessageAsync(MqttApplicationMessage applicationMessage) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - var result = await SubscriptionsManager.CheckSubscriptionsAsync(applicationMessage); + var result = await SubscriptionsManager.CheckSubscriptionsAsync(applicationMessage).ConfigureAwait(false); if (!result.IsSubscribed) { return; @@ -153,10 +158,10 @@ namespace MQTTnet.Server { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - await SubscriptionsManager.SubscribeAsync(new MqttSubscribePacket + SubscriptionsManager.Subscribe(new MqttSubscribePacket { TopicFilters = topicFilters - }).ConfigureAwait(false); + }); await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false); } @@ -165,26 +170,26 @@ namespace MQTTnet.Server { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return SubscriptionsManager.UnsubscribeAsync(new MqttUnsubscribePacket + SubscriptionsManager.Unsubscribe(new MqttUnsubscribePacket { TopicFilters = topicFilters }); + + return Task.FromResult(0); } public void Dispose() { - ApplicationMessageReceivedCallback = null; - SubscriptionsManager?.Dispose(); PendingMessagesQueue?.Dispose(); _cancellationTokenSource?.Dispose(); } - private Task StopDueToKeepAliveTimeoutAsync() + private void StopDueToKeepAliveTimeout() { _logger.Info("Client '{0}': Timeout while waiting for KeepAlive packet.", ClientId); - return StopAsync(); + Stop(MqttClientDisconnectType.NotClean); } private async Task ReceivePacketsAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken) @@ -201,15 +206,18 @@ namespace MQTTnet.Server catch (OperationCanceledException) { } - catch (MqttCommunicationException exception) - { - _logger.Warning(exception, "Client '{0}': Communication exception while processing client packets.", ClientId); - await StopAsync().ConfigureAwait(false); - } catch (Exception exception) { - _logger.Error(exception, "Client '{0}': Unhandled exception while processing client packets.", ClientId); - await StopAsync().ConfigureAwait(false); + if (exception is MqttCommunicationException) + { + _logger.Warning(exception, "Client '{0}': Communication exception while processing client packets.", ClientId); + } + else + { + _logger.Error(exception, "Client '{0}': Unhandled exception while processing client packets.", ClientId); + } + + Stop(MqttClientDisconnectType.NotClean); } } @@ -222,7 +230,7 @@ namespace MQTTnet.Server if (packet is MqttPingReqPacket) { - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { new MqttPingRespPacket() }); + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { new MqttPingRespPacket() }, cancellationToken); } if (packet is MqttPubRelPacket pubRelPacket) @@ -237,7 +245,7 @@ namespace MQTTnet.Server PacketIdentifier = pubRecPacket.PacketIdentifier }; - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { responsePacket }); + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { responsePacket }, cancellationToken); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -258,26 +266,40 @@ namespace MQTTnet.Server if (packet is MqttDisconnectPacket) { - return StopAsync(true); + Stop(MqttClientDisconnectType.Clean); + return Task.FromResult(0); } if (packet is MqttConnectPacket) { - return StopAsync(); + Stop(MqttClientDisconnectType.NotClean); + return Task.FromResult(0); } _logger.Warning("Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); - return StopAsync(); + Stop(MqttClientDisconnectType.NotClean); + + return Task.FromResult(0); + } + + private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) + { + var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters); + foreach (var applicationMessage in retainedMessages) + { + await EnqueueApplicationMessageAsync(applicationMessage).ConfigureAwait(false); + } } private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { - var subscribeResult = await SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { subscribeResult.ResponsePacket }).ConfigureAwait(false); + var subscribeResult = SubscriptionsManager.Subscribe(subscribePacket); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { subscribeResult.ResponsePacket }, cancellationToken).ConfigureAwait(false); if (subscribeResult.CloseConnection) { - await StopAsync().ConfigureAwait(false); + Stop(MqttClientDisconnectType.NotClean); + return; } await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); @@ -285,17 +307,8 @@ namespace MQTTnet.Server private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { - var unsubscribeResult = await SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { unsubscribeResult }); - } - - private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) - { - var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters); - foreach (var applicationMessage in retainedMessages) - { - await EnqueueApplicationMessageAsync(applicationMessage).ConfigureAwait(false); - } + var unsubscribeResult = SubscriptionsManager.Unsubscribe(unsubscribePacket); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { unsubscribeResult }, cancellationToken); } private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) @@ -306,7 +319,8 @@ namespace MQTTnet.Server { case MqttQualityOfServiceLevel.AtMostOnce: { - return ApplicationMessageReceivedCallback?.Invoke(this, applicationMessage); + _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + return Task.FromResult(0); } case MqttQualityOfServiceLevel.AtLeastOnce: { @@ -325,25 +339,25 @@ namespace MQTTnet.Server private async Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - await ApplicationMessageReceivedCallback(this, applicationMessage).ConfigureAwait(false); + _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { response }).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { response }, cancellationToken).ConfigureAwait(false); } private async Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { // QoS 2 is implement as method "B" [4.3.3 QoS 2: Exactly once delivery] - await ApplicationMessageReceivedCallback(this, applicationMessage).ConfigureAwait(false); + _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { response }).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { response }, cancellationToken).ConfigureAwait(false); } private Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }; - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { response }); + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { response }, cancellationToken); } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs index 8dd9d06..98e5835 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs @@ -8,38 +8,34 @@ using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Packets; using MQTTnet.Protocol; -using MQTTnet.Serializer; namespace MQTTnet.Server { public sealed class MqttClientSessionsManager : IDisposable { private readonly Dictionary _sessions = new Dictionary(); - private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + private readonly SemaphoreSlim _sessionsLock = new SemaphoreSlim(1, 1); - private readonly IMqttServerOptions _options; private readonly MqttRetainedMessagesManager _retainedMessagesManager; + private readonly IMqttServerOptions _options; private readonly IMqttNetLogger _logger; - public MqttClientSessionsManager(IMqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetLogger logger) + public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetLogger logger) { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _options = options ?? throw new ArgumentNullException(nameof(options)); + Server = server ?? throw new ArgumentNullException(nameof(server)); _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); } - public Action ClientConnectedCallback { get; set; } - public Action ClientDisconnectedCallback { get; set; } - public Action ClientSubscribedTopicCallback { get; set; } - public Action ClientUnsubscribedTopicCallback { get; set; } - public Action ApplicationMessageReceivedCallback { get; set; } + public MqttServer Server { get; } public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) { var clientId = string.Empty; var wasCleanDisconnect = false; MqttClientSession clientSession = null; - + try { if (!(await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken) @@ -57,30 +53,30 @@ namespace MQTTnet.Server var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] + await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { new MqttConnAckPacket { ConnectReturnCode = connectReturnCode } - }).ConfigureAwait(false); + }, cancellationToken).ConfigureAwait(false); return; } - var result = await GetOrCreateClientSessionAsync(connectPacket).ConfigureAwait(false); + var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); clientSession = result.Session; - await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] + await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { new MqttConnAckPacket { ConnectReturnCode = connectReturnCode, IsSessionPresent = result.IsExistingSession } - }).ConfigureAwait(false); + }, cancellationToken).ConfigureAwait(false); - ClientConnectedCallback?.Invoke(new ConnectedMqttClient + Server.OnClientConnected(new ConnectedMqttClient { ClientId = clientId, ProtocolVersion = clientAdapter.PacketSerializer.ProtocolVersion @@ -99,15 +95,15 @@ namespace MQTTnet.Server { try { - await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); clientAdapter.Dispose(); } catch (Exception exception) { _logger.Error(exception, exception.Message); } - - ClientDisconnectedCallback?.Invoke(new ConnectedMqttClient + + Server.OnClientDisconnected(new ConnectedMqttClient { ClientId = clientId, ProtocolVersion = clientAdapter.PacketSerializer.ProtocolVersion, @@ -119,31 +115,31 @@ namespace MQTTnet.Server public async Task StopAsync() { - await _semaphore.WaitAsync().ConfigureAwait(false); + await _sessionsLock.WaitAsync().ConfigureAwait(false); try { foreach (var session in _sessions) { - await session.Value.StopAsync().ConfigureAwait(false); + session.Value.Stop(MqttClientDisconnectType.NotClean); } _sessions.Clear(); } finally { - _semaphore.Release(); + _sessionsLock.Release(); } } public async Task> GetConnectedClientsAsync() { - await _semaphore.WaitAsync().ConfigureAwait(false); + await _sessionsLock.WaitAsync().ConfigureAwait(false); try { return _sessions.Where(s => s.Value.IsConnected).Select(s => new ConnectedMqttClient { ClientId = s.Value.ClientId, - ProtocolVersion = s.Value.ProtocolVersion ?? MqttProtocolVersion.V311, + ProtocolVersion = s.Value.ProtocolVersion, LastPacketReceived = s.Value.KeepAliveMonitor.LastPacketReceived, LastNonKeepAlivePacketReceived = s.Value.KeepAliveMonitor.LastNonKeepAlivePacketReceived, PendingApplicationMessages = s.Value.PendingMessagesQueue.Count @@ -151,49 +147,13 @@ namespace MQTTnet.Server } finally { - _semaphore.Release(); + _sessionsLock.Release(); } } - public async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) { - try - { - var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage); - if (interceptorContext.CloseConnection) - { - await senderClientSession.StopAsync().ConfigureAwait(false); - } - - if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) - { - return; - } - - if (applicationMessage.Retain) - { - await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false); - } - - ApplicationMessageReceivedCallback?.Invoke(senderClientSession?.ClientId, applicationMessage); - } - catch (Exception exception) - { - _logger.Error(exception, "Error while processing application message"); - } - - await _semaphore.WaitAsync().ConfigureAwait(false); - try - { - foreach (var clientSession in _sessions.Values) - { - await clientSession.EnqueueApplicationMessageAsync(applicationMessage); - } - } - finally - { - _semaphore.Release(); - } + Task.Run(() => DispatchApplicationMessageAsync(senderClientSession, applicationMessage)); } public async Task SubscribeAsync(string clientId, IList topicFilters) @@ -201,19 +161,19 @@ namespace MQTTnet.Server if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - await _semaphore.WaitAsync().ConfigureAwait(false); + await _sessionsLock.WaitAsync().ConfigureAwait(false); try { if (!_sessions.TryGetValue(clientId, out var session)) { - throw new InvalidOperationException($"Client session {clientId} is unknown."); + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } - await session.SubscribeAsync(topicFilters); + await session.SubscribeAsync(topicFilters).ConfigureAwait(false); } finally { - _semaphore.Release(); + _sessionsLock.Release(); } } @@ -222,36 +182,25 @@ namespace MQTTnet.Server if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - await _semaphore.WaitAsync().ConfigureAwait(false); + await _sessionsLock.WaitAsync().ConfigureAwait(false); try { if (!_sessions.TryGetValue(clientId, out var session)) { - throw new InvalidOperationException($"Client session {clientId} is unknown."); + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } - await session.UnsubscribeAsync(topicFilters); + await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false); } finally { - _semaphore.Release(); + _sessionsLock.Release(); } } - private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public void Dispose() { - var interceptorContext = new MqttApplicationMessageInterceptorContext( - senderClientSession?.ClientId, - applicationMessage); - - var interceptor = _options.ApplicationMessageInterceptor; - if (interceptor == null) - { - return interceptorContext; - } - - interceptor(interceptorContext); - return interceptorContext; + _sessionsLock?.Dispose(); } private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) @@ -271,9 +220,9 @@ namespace MQTTnet.Server return context.ReturnCode; } - private async Task GetOrCreateClientSessionAsync(MqttConnectPacket connectPacket) + private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) { - await _semaphore.WaitAsync().ConfigureAwait(false); + await _sessionsLock.WaitAsync().ConfigureAwait(false); try { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); @@ -283,7 +232,7 @@ namespace MQTTnet.Server { _sessions.Remove(connectPacket.ClientId); - await clientSession.StopAsync().ConfigureAwait(false); + clientSession.Stop(MqttClientDisconnectType.Clean); clientSession.Dispose(); clientSession = null; @@ -300,14 +249,7 @@ namespace MQTTnet.Server { isExistingSession = false; - clientSession = new MqttClientSession(connectPacket.ClientId, _options, _retainedMessagesManager, _logger) - { - ApplicationMessageReceivedCallback = DispatchApplicationMessageAsync - }; - - clientSession.SubscriptionsManager.TopicSubscribedCallback = ClientSubscribedTopicCallback; - clientSession.SubscriptionsManager.TopicUnsubscribedCallback = ClientUnsubscribedTopicCallback; - + clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _logger); _sessions[connectPacket.ClientId] = clientSession; _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); @@ -317,19 +259,65 @@ namespace MQTTnet.Server } finally { - _semaphore.Release(); + _sessionsLock.Release(); } } - public void Dispose() + private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + { + try + { + var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage); + if (interceptorContext.CloseConnection) + { + senderClientSession.Stop(MqttClientDisconnectType.NotClean); + } + + if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) + { + return; + } + + if (applicationMessage.Retain) + { + await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false); + } + + Server.OnApplicationMessageReceived(senderClientSession?.ClientId, applicationMessage); + } + catch (Exception exception) + { + _logger.Error(exception, "Error while processing application message"); + } + + await _sessionsLock.WaitAsync().ConfigureAwait(false); + try + { + foreach (var clientSession in _sessions.Values) + { + await clientSession.EnqueueApplicationMessageAsync(applicationMessage).ConfigureAwait(false); + } + } + finally + { + _sessionsLock.Release(); + } + } + + private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) { - ClientConnectedCallback = null; - ClientDisconnectedCallback = null; - ClientSubscribedTopicCallback = null; - ClientUnsubscribedTopicCallback = null; - ApplicationMessageReceivedCallback = null; + var interceptorContext = new MqttApplicationMessageInterceptorContext( + senderClientSession?.ClientId, + applicationMessage); + + var interceptor = _options.ApplicationMessageInterceptor; + if (interceptor == null) + { + return interceptorContext; + } - _semaphore?.Dispose(); + interceptor(interceptorContext); + return interceptorContext; } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs index e4ee921..dfb9463 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading; @@ -10,21 +11,20 @@ namespace MQTTnet.Server { public sealed class MqttClientSubscriptionsManager : IDisposable { + private readonly ConcurrentDictionary _subscriptions = new ConcurrentDictionary(); private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); - private readonly Dictionary _subscriptions = new Dictionary(); private readonly IMqttServerOptions _options; + private readonly MqttServer _server; private readonly string _clientId; - public MqttClientSubscriptionsManager(IMqttServerOptions options, string clientId) + public MqttClientSubscriptionsManager(string clientId, IMqttServerOptions options, MqttServer server) { - _options = options ?? throw new ArgumentNullException(nameof(options)); _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + _server = server; } - public Action TopicSubscribedCallback { get; set; } - public Action TopicUnsubscribedCallback { get; set; } - - public async Task SubscribeAsync(MqttSubscribePacket subscribePacket) + public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); @@ -38,57 +38,41 @@ namespace MQTTnet.Server CloseConnection = false }; - await _semaphore.WaitAsync().ConfigureAwait(false); - try + foreach (var topicFilter in subscribePacket.TopicFilters) { - foreach (var topicFilter in subscribePacket.TopicFilters) + var interceptorContext = InterceptSubscribe(topicFilter); + if (!interceptorContext.AcceptSubscription) { - var interceptorContext = InterceptSubscribe(topicFilter); - if (!interceptorContext.AcceptSubscription) - { - result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure); - } - else - { - result.ResponsePacket.SubscribeReturnCodes.Add(ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel)); - } + result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure); + } + else + { + result.ResponsePacket.SubscribeReturnCodes.Add(ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel)); + } - if (interceptorContext.CloseConnection) - { - result.CloseConnection = true; - } + if (interceptorContext.CloseConnection) + { + result.CloseConnection = true; + } - if (interceptorContext.AcceptSubscription) - { - _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; - TopicSubscribedCallback?.Invoke(_clientId, topicFilter); - } + if (interceptorContext.AcceptSubscription) + { + _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + _server.OnClientSubscribedTopic(_clientId, topicFilter); } } - finally - { - _semaphore.Release(); - } return result; } - public async Task UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket) + public MqttUnsubAckPacket Unsubscribe(MqttUnsubscribePacket unsubscribePacket) { if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); - await _semaphore.WaitAsync().ConfigureAwait(false); - try + foreach (var topicFilter in unsubscribePacket.TopicFilters) { - foreach (var topicFilter in unsubscribePacket.TopicFilters) - { - _subscriptions.Remove(topicFilter); - TopicUnsubscribedCallback?.Invoke(_clientId, topicFilter); - } - } - finally - { - _semaphore.Release(); + _subscriptions.TryRemove(topicFilter, out _); + _server.OnClientUnsubscribedTopic(_clientId, topicFilter); } return new MqttUnsubAckPacket diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs index 6329442..b32869d 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs @@ -57,7 +57,7 @@ namespace MQTTnet.Server return _clientSessionsManager.UnsubscribeAsync(clientId, topicFilters); } - public async Task PublishAsync(IEnumerable applicationMessages) + public Task PublishAsync(IEnumerable applicationMessages) { if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); @@ -65,8 +65,10 @@ namespace MQTTnet.Server foreach (var applicationMessage in applicationMessages) { - await _clientSessionsManager.DispatchApplicationMessageAsync(null, applicationMessage); + _clientSessionsManager.StartDispatchApplicationMessage(null, applicationMessage); } + + return Task.FromResult(0); } public async Task StartAsync(IMqttServerOptions options) @@ -80,14 +82,7 @@ namespace MQTTnet.Server _retainedMessagesManager = new MqttRetainedMessagesManager(Options, _logger); await _retainedMessagesManager.LoadMessagesAsync(); - _clientSessionsManager = new MqttClientSessionsManager(Options, _retainedMessagesManager, _logger) - { - ClientConnectedCallback = OnClientConnected, - ClientDisconnectedCallback = OnClientDisconnected, - ClientSubscribedTopicCallback = OnClientSubscribedTopic, - ClientUnsubscribedTopicCallback = OnClientUnsubscribedTopic, - ApplicationMessageReceivedCallback = OnApplicationMessageReceived - }; + _clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _logger); foreach (var adapter in _adapters) { @@ -132,29 +127,29 @@ namespace MQTTnet.Server } } - private void OnClientConnected(ConnectedMqttClient client) + internal void OnClientConnected(ConnectedMqttClient client) { _logger.Info("Client '{0}': Connected.", client.ClientId); ClientConnected?.Invoke(this, new MqttClientConnectedEventArgs(client)); } - private void OnClientDisconnected(ConnectedMqttClient client, bool wasCleanDisconnect) + internal void OnClientDisconnected(ConnectedMqttClient client, bool wasCleanDisconnect) { _logger.Info("Client '{0}': Disconnected (clean={1}).", client.ClientId, wasCleanDisconnect); ClientDisconnected?.Invoke(this, new MqttClientDisconnectedEventArgs(client, wasCleanDisconnect)); } - private void OnClientSubscribedTopic(string clientId, TopicFilter topicFilter) + internal void OnClientSubscribedTopic(string clientId, TopicFilter topicFilter) { ClientSubscribedTopic?.Invoke(this, new MqttClientSubscribedTopicEventArgs(clientId, topicFilter)); } - private void OnClientUnsubscribedTopic(string clientId, string topicFilter) + internal void OnClientUnsubscribedTopic(string clientId, string topicFilter) { ClientUnsubscribedTopic?.Invoke(this, new MqttClientUnsubscribedTopicEventArgs(clientId, topicFilter)); } - private void OnApplicationMessageReceived(string clientId, MqttApplicationMessage applicationMessage) + internal void OnApplicationMessageReceived(string clientId, MqttApplicationMessage applicationMessage) { ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(clientId, applicationMessage)); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttTopicFilterComparer.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttTopicFilterComparer.cs index 8244415..00e0bd6 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttTopicFilterComparer.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttTopicFilterComparer.cs @@ -19,12 +19,17 @@ namespace MQTTnet.Server var fragmentsTopic = topic.Split(TopicLevelSeparator, StringSplitOptions.None); var fragmentsFilter = filter.Split(TopicLevelSeparator, StringSplitOptions.None); + // # > In either case it MUST be the last character specified in the Topic Filter [MQTT-4.7.1-2]. for (var i = 0; i < fragmentsFilter.Length; i++) { - switch (fragmentsFilter[i]) + if (fragmentsFilter[i] == "+") { - case "+": continue; - case "#" when i == fragmentsFilter.Length - 1: return true; + continue; + } + + if (fragmentsFilter[i] == "#") + { + return true; } if (i >= fragmentsTopic.Length) @@ -38,7 +43,7 @@ namespace MQTTnet.Server } } - return fragmentsTopic.Length <= fragmentsFilter.Length; + return fragmentsTopic.Length == fragmentsFilter.Length; } } } diff --git a/Frameworks/MQTTnet.Netstandard/MQTTnet.NetStandard.csproj b/Frameworks/MQTTnet.Netstandard/MQTTnet.NetStandard.csproj index cb154c2..fc1e41a 100644 --- a/Frameworks/MQTTnet.Netstandard/MQTTnet.NetStandard.csproj +++ b/Frameworks/MQTTnet.Netstandard/MQTTnet.NetStandard.csproj @@ -6,7 +6,6 @@ MQTTnet MQTTnet False - Full 0.0.0.0 0.0.0.0 0.0.0.0 @@ -19,7 +18,7 @@ false - + false UAP,Version=v10.0 UAP @@ -32,6 +31,10 @@ $(MSBuildExtensionsPath)\Microsoft\WindowsXaml\v$(VisualStudioVersion)\Microsoft.Windows.UI.Xaml.CSharp.targets + + Full + + @@ -43,17 +46,17 @@ - + - + - + diff --git a/MQTTnet.sln b/MQTTnet.sln index 2b31673..a9bd645 100644 --- a/MQTTnet.sln +++ b/MQTTnet.sln @@ -38,7 +38,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Extensions", "Extensions", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.Extensions.Rpc", "Extensions\MQTTnet.Extensions.Rpc\MQTTnet.Extensions.Rpc.csproj", "{C444E9C8-95FA-430E-9126-274129DE16CD}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MQTTserver", "MQTTserver\MQTTserver.csproj", "{5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MQTTnet.Benchmarks", "Tests\MQTTnet.Benchmarks\MQTTnet.Benchmarks.csproj", "{998D04DD-7CB0-45F5-A393-E2495C16399E}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -170,22 +170,22 @@ Global {C444E9C8-95FA-430E-9126-274129DE16CD}.Release|x64.Build.0 = Release|Any CPU {C444E9C8-95FA-430E-9126-274129DE16CD}.Release|x86.ActiveCfg = Release|Any CPU {C444E9C8-95FA-430E-9126-274129DE16CD}.Release|x86.Build.0 = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|Any CPU.Build.0 = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|ARM.ActiveCfg = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|ARM.Build.0 = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|x64.ActiveCfg = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|x64.Build.0 = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|x86.ActiveCfg = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Debug|x86.Build.0 = Debug|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|Any CPU.ActiveCfg = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|Any CPU.Build.0 = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|ARM.ActiveCfg = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|ARM.Build.0 = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|x64.ActiveCfg = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|x64.Build.0 = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|x86.ActiveCfg = Release|Any CPU - {5FCCD9CE-9E7E-40C1-9B99-3328FED9EED7}.Release|x86.Build.0 = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|ARM.ActiveCfg = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|ARM.Build.0 = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|x64.ActiveCfg = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|x64.Build.0 = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|x86.ActiveCfg = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Debug|x86.Build.0 = Debug|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|Any CPU.Build.0 = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|ARM.ActiveCfg = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|ARM.Build.0 = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|x64.ActiveCfg = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|x64.Build.0 = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|x86.ActiveCfg = Release|Any CPU + {998D04DD-7CB0-45F5-A393-E2495C16399E}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -198,6 +198,7 @@ Global {C6FF8AEA-0855-41EC-A1F3-AC262225BAB9} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} {F10C4060-F7EE-4A83-919F-FF723E72F94A} = {32A630A7-2598-41D7-B625-204CD906F5FB} {C444E9C8-95FA-430E-9126-274129DE16CD} = {12816BCC-AF9E-44A9-9AE5-C246AF2A0587} + {998D04DD-7CB0-45F5-A393-E2495C16399E} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {07536672-5CBC-4BE3-ACE0-708A431A7894} diff --git a/README.md b/README.md index 6a3111d..cff5490 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,10 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov * TLS 1.2 support for client and server (but not UWP servers) * Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS) * Lightweight (only the low level implementation of MQTT, no overhead) -* Performance optimized (processing ~40.000 messages / second)* +* Performance optimized (processing ~60.000 messages / second)* * Interfaces included for mocking and testing * Access to internal trace messages -* Unit tested (70+ tests) +* Unit tested (~80 tests) \* Tested on local machine (Intel i7 8700K) with MQTTnet client and server running in the same process using the TCP channel. The app for verification is part of this repository and stored in _/Tests/MQTTnet.TestApp.NetCore_. @@ -86,7 +86,7 @@ If you use this library and want to see your project here please let me know. ## MIT License -Copyright (c) 2017 Christian Kratky +Copyright (c) 2017-2018 Christian Kratky Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Tests/MQTTnet.Benchmarks/App.config b/Tests/MQTTnet.Benchmarks/App.config new file mode 100644 index 0000000..5a41bc2 --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/App.config @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj new file mode 100644 index 0000000..3b52bba --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -0,0 +1,170 @@ + + + + + Debug + AnyCPU + {998D04DD-7CB0-45F5-A393-E2495C16399E} + Exe + MQTTnet.Benchmarks + MQTTnet.Benchmarks + v4.6.2 + 512 + true + + + AnyCPU + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + AnyCPU + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + ..\..\packages\BenchmarkDotNet.0.10.14\lib\net46\BenchmarkDotNet.dll + + + ..\..\packages\BenchmarkDotNet.Core.0.10.14\lib\net46\BenchmarkDotNet.Core.dll + + + ..\..\packages\BenchmarkDotNet.Toolchains.Roslyn.0.10.14\lib\net46\BenchmarkDotNet.Toolchains.Roslyn.dll + + + ..\..\packages\Microsoft.CodeAnalysis.Common.2.6.1\lib\netstandard1.3\Microsoft.CodeAnalysis.dll + + + ..\..\packages\Microsoft.CodeAnalysis.CSharp.2.6.1\lib\netstandard1.3\Microsoft.CodeAnalysis.CSharp.dll + + + ..\..\packages\Microsoft.DotNet.InternalAbstractions.1.0.0\lib\net451\Microsoft.DotNet.InternalAbstractions.dll + + + ..\..\packages\Microsoft.DotNet.PlatformAbstractions.1.1.1\lib\net451\Microsoft.DotNet.PlatformAbstractions.dll + + + ..\..\packages\Microsoft.Win32.Registry.4.3.0\lib\net46\Microsoft.Win32.Registry.dll + + + + ..\..\packages\System.AppContext.4.3.0\lib\net46\System.AppContext.dll + True + + + ..\..\packages\System.Collections.Immutable.1.3.1\lib\portable-net45+win8+wp8+wpa81\System.Collections.Immutable.dll + True + + + + ..\..\packages\System.Console.4.3.0\lib\net46\System.Console.dll + + + + ..\..\packages\System.Diagnostics.FileVersionInfo.4.3.0\lib\net46\System.Diagnostics.FileVersionInfo.dll + + + ..\..\packages\System.Diagnostics.StackTrace.4.3.0\lib\net46\System.Diagnostics.StackTrace.dll + + + ..\..\packages\System.IO.Compression.4.3.0\lib\net46\System.IO.Compression.dll + True + + + ..\..\packages\System.IO.FileSystem.4.3.0\lib\net46\System.IO.FileSystem.dll + + + ..\..\packages\System.IO.FileSystem.Primitives.4.3.0\lib\net46\System.IO.FileSystem.Primitives.dll + + + + + ..\..\packages\System.Reflection.4.3.0\lib\net462\System.Reflection.dll + + + ..\..\packages\System.Reflection.Metadata.1.4.2\lib\portable-net45+win8\System.Reflection.Metadata.dll + + + ..\..\packages\System.Runtime.4.3.0\lib\net462\System.Runtime.dll + + + ..\..\packages\System.Runtime.Extensions.4.3.0\lib\net462\System.Runtime.Extensions.dll + + + ..\..\packages\System.Runtime.InteropServices.4.3.0\lib\net462\System.Runtime.InteropServices.dll + + + ..\..\packages\System.Security.Cryptography.Algorithms.4.3.0\lib\net461\System.Security.Cryptography.Algorithms.dll + + + ..\..\packages\System.Security.Cryptography.Encoding.4.3.0\lib\net46\System.Security.Cryptography.Encoding.dll + + + ..\..\packages\System.Security.Cryptography.Primitives.4.3.0\lib\net46\System.Security.Cryptography.Primitives.dll + + + ..\..\packages\System.Security.Cryptography.X509Certificates.4.3.0\lib\net461\System.Security.Cryptography.X509Certificates.dll + + + ..\..\packages\System.Text.Encoding.CodePages.4.3.0\lib\net46\System.Text.Encoding.CodePages.dll + + + ..\..\packages\System.Threading.Tasks.Extensions.4.3.0\lib\portable-net45+win8+wp8+wpa81\System.Threading.Tasks.Extensions.dll + + + ..\..\packages\System.Threading.Thread.4.3.0\lib\net46\System.Threading.Thread.dll + + + ..\..\packages\System.ValueTuple.4.3.0\lib\netstandard1.0\System.ValueTuple.dll + + + + + + + + + ..\..\packages\System.Xml.ReaderWriter.4.3.0\lib\net46\System.Xml.ReaderWriter.dll + + + ..\..\packages\System.Xml.XmlDocument.4.3.0\lib\net46\System.Xml.XmlDocument.dll + + + ..\..\packages\System.Xml.XPath.4.3.0\lib\net46\System.Xml.XPath.dll + + + ..\..\packages\System.Xml.XPath.XDocument.4.3.0\lib\net46\System.Xml.XPath.XDocument.dll + + + + + + + + + + + + + + + + + + + {3587E506-55A2-4EB3-99C7-DC01E42D25D2} + MQTTnet.NetStandard + + + + \ No newline at end of file diff --git a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs new file mode 100644 index 0000000..53b71fd --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -0,0 +1,47 @@ +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Attributes.Columns; +using BenchmarkDotNet.Attributes.Exporters; +using BenchmarkDotNet.Attributes.Jobs; +using MQTTnet.Client; +using MQTTnet.Server; + +namespace MQTTnet.Benchmarks +{ + [ClrJob] + [RPlotExporter, RankColumn] + public class MessageProcessingBenchmark + { + private IMqttServer _mqttServer; + private IMqttClient _mqttClient; + private MqttApplicationMessage _message; + + [GlobalSetup] + public void Setup() + { + var factory = new MqttFactory(); + _mqttServer = factory.CreateMqttServer(); + _mqttClient = factory.CreateMqttClient(); + + var serverOptions = new MqttServerOptionsBuilder().Build(); + _mqttServer.StartAsync(serverOptions).GetAwaiter().GetResult(); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost").Build(); + + _mqttClient.ConnectAsync(clientOptions).GetAwaiter().GetResult(); + + _message = new MqttApplicationMessageBuilder() + .WithTopic("A") + .Build(); + } + + [Benchmark] + public void Send_10000_Messages() + { + for (var i = 0; i < 10000; i++) + { + _mqttClient.PublishAsync(_message).GetAwaiter().GetResult(); + } + } + } +} diff --git a/Tests/MQTTnet.Benchmarks/Program.cs b/Tests/MQTTnet.Benchmarks/Program.cs new file mode 100644 index 0000000..ad2c363 --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/Program.cs @@ -0,0 +1,32 @@ +using System; +using System.Threading; +using BenchmarkDotNet.Running; +using MQTTnet.Diagnostics; + +namespace MQTTnet.Benchmarks +{ + public static class Program + { + public static void Main(string[] args) + { + Console.WriteLine($"MQTTnet - BenchmarkApp.{TargetFrameworkInfoProvider.TargetFramework}"); + Console.WriteLine("1 = MessageProcessingBenchmark"); + Console.WriteLine("2 = SerializerBenchmark"); + + var pressedKey = Console.ReadKey(true); + switch (pressedKey.KeyChar) + { + case '1': + BenchmarkRunner.Run(); + break; + case '2': + BenchmarkRunner.Run(); + break; + default: + break; + } + + Console.ReadLine(); + } + } +} diff --git a/Tests/MQTTnet.Benchmarks/Properties/AssemblyInfo.cs b/Tests/MQTTnet.Benchmarks/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..4df83f6 --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("MQTTnet.Benchmarks")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("MQTTnet.Benchmarks")] +[assembly: AssemblyCopyright("Copyright © 2018")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("998d04dd-7cb0-45f5-a393-e2495c16399e")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs new file mode 100644 index 0000000..bcfff1c --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -0,0 +1,74 @@ +using BenchmarkDotNet.Attributes; +using MQTTnet.Client; +using MQTTnet.Packets; +using MQTTnet.Serializer; +using MQTTnet.Internal; +using MQTTnet.Server; +using BenchmarkDotNet.Attributes.Jobs; +using BenchmarkDotNet.Attributes.Exporters; +using System; +using System.Threading; +using System.IO; +using MQTTnet.Core.Internal; + +namespace MQTTnet.Benchmarks +{ + [ClrJob] + [RPlotExporter] + [MemoryDiagnoser] + public class SerializerBenchmark + { + private MqttBasePacket _packet; + private ArraySegment _serializedPacket; + private MqttPacketSerializer _serializer; + + [GlobalSetup] + public void Setup() + { + var message = new MqttApplicationMessageBuilder() + .WithTopic("A") + .Build(); + + _packet = message.ToPublishPacket(); + _serializer = new MqttPacketSerializer(); + _serializedPacket = _serializer.Serialize(_packet); + } + + [Benchmark] + public void Serialize_10000_Messages() + { + for (var i = 0; i < 10000; i++) + { + _serializer.Serialize(_packet); + } + } + + [Benchmark] + public void Deserialize_10000_Messages() + { + for (var i = 0; i < 10000; i++) + { + using (var headerStream = new MemoryStream(Join(_serializedPacket))) + { + var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(headerStream), CancellationToken.None).GetAwaiter().GetResult(); + + using (var bodyStream = new MemoryStream(Join(_serializedPacket), (int)headerStream.Position, header.BodyLength)) + { + _serializer.Deserialize(header, bodyStream); + } + } + } + } + + private static byte[] Join(params ArraySegment[] chunks) + { + var buffer = new MemoryStream(); + foreach (var chunk in chunks) + { + buffer.Write(chunk.Array, chunk.Offset, chunk.Count); + } + + return buffer.ToArray(); + } + } +} diff --git a/Tests/MQTTnet.Benchmarks/packages.config b/Tests/MQTTnet.Benchmarks/packages.config new file mode 100644 index 0000000..bb93014 --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/packages.config @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs new file mode 100644 index 0000000..77275eb --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs @@ -0,0 +1,34 @@ +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class AsyncAutoResetEventTests + { + [TestMethod] + public async Task AsyncAutoResetEvent() + { + var aare = new AsyncAutoResetEvent(); + + var increment = 0; + var globalI = 0; +#pragma warning disable 4014 + Task.Run(async () => +#pragma warning restore 4014 + { + await aare.WaitOneAsync(CancellationToken.None); + globalI += increment; + }); + + await Task.Delay(500); + increment = 1; + aare.Set(); + await Task.Delay(100); + + Assert.AreEqual(1, globalI); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs new file mode 100644 index 0000000..56c7050 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs @@ -0,0 +1,44 @@ +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class AsyncLockTests + { + [TestMethod] + public void AsyncLock() + { + const int ThreadsCount = 10; + + var threads = new Task[ThreadsCount]; + var @lock = new AsyncLock(); + var globalI = 0; + for (var i = 0; i < ThreadsCount; i++) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + threads[i] = Task.Run(async () => + { + await @lock.EnterAsync(CancellationToken.None); + try + { + var localI = globalI; + await Task.Delay(10); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + finally + { + @lock.Exit(); + } + }); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + Task.WaitAll(threads); + Assert.AreEqual(ThreadsCount, globalI); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index 053b0ca..696cfa3 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -1,9 +1,9 @@ using System; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Exceptions; -using MQTTnet.Internal; namespace MQTTnet.Core.Tests { @@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests [TestMethod] public async Task TimeoutAfter() { - await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] public async Task TimeoutAfterWithResult() { - await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [TestMethod] public async Task TimeoutAfterCompleteInTime() { - var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500)); + var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); Assert.AreEqual(5, result); } @@ -36,17 +36,17 @@ namespace MQTTnet.Core.Tests { try { - await Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; - }).TimeoutAfter(TimeSpan.FromSeconds(1)); + }, ct), TimeSpan.FromSeconds(1), CancellationToken.None); Assert.Fail(); } - catch (MqttCommunicationException e) + catch (Exception e) { - Assert.IsTrue(e.InnerException is IndexOutOfRangeException); + Assert.IsTrue(e is IndexOutOfRangeException); } } @@ -55,17 +55,18 @@ namespace MQTTnet.Core.Tests { try { - await Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => { var iis = new int[0]; - return iis[1]; - }).TimeoutAfter(TimeSpan.FromSeconds(1)); + iis[1] = 0; + return iis[0]; + }, ct), TimeSpan.FromSeconds(1), CancellationToken.None); Assert.Fail(); } - catch (MqttCommunicationException e) + catch (Exception e) { - Assert.IsTrue(e.InnerException is IndexOutOfRangeException); + Assert.IsTrue(e is IndexOutOfRangeException); } } @@ -73,7 +74,10 @@ namespace MQTTnet.Core.Tests public async Task TimeoutAfterMemoryUsage() { var tasks = Enumerable.Range(0, 100000) - .Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1))); + .Select(i => + { + return MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); + }); await Task.WhenAll(tasks); AssertIsLess(3_000_000, GC.GetTotalMemory(true)); diff --git a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs index 206942f..f3c4320 100644 --- a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs @@ -1,5 +1,4 @@ using System.Threading; -using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Diagnostics; using MQTTnet.Packets; @@ -18,7 +17,6 @@ namespace MQTTnet.Core.Tests var monitor = new MqttClientKeepAliveMonitor(string.Empty, delegate { timeoutCalledCount++; - return Task.FromResult(0); }, new MqttNetLogger()); Assert.AreEqual(0, timeoutCalledCount); @@ -40,7 +38,6 @@ namespace MQTTnet.Core.Tests var monitor = new MqttClientKeepAliveMonitor(string.Empty, delegate { timeoutCalledCount++; - return Task.FromResult(0); }, new MqttNetLogger()); Assert.AreEqual(0, timeoutCalledCount); diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index 4d28035..72a675b 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -1,6 +1,7 @@ using System.IO; using System.Threading; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Internal; using MQTTnet.Serializer; namespace MQTTnet.Core.Tests @@ -11,8 +12,7 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttPacketReader_EmptyStream() { - var memStream = new MemoryStream(); - var header = MqttPacketReader.ReadHeaderAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); + var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); Assert.IsNull(header); } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 1111d90..a45736d 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -1,9 +1,9 @@ using System; -using System.Collections.Generic; using System.IO; using System.Text; using System.Threading; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Serializer; @@ -25,7 +25,7 @@ namespace MQTTnet.Core.Tests CleanSession = true }; - SerializeAndCompare(p, "EB0ABE1RSXNkcAPCAHsAA1hZWgAEVVNFUgAEUEFTUw==", MqttProtocolVersion.V310); + SerializeAndCompare(p, "EB0ABk1RSXNkcAPCAHsAA1hZWgAEVVNFUgAEUEFTUw==", MqttProtocolVersion.V310); } [TestMethod] @@ -403,9 +403,9 @@ namespace MQTTnet.Core.Tests private static void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; - var chunks = serializer.Serialize(packet); + var data = serializer.Serialize(packet); - Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(chunks))); + Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(data))); } private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) @@ -416,7 +416,7 @@ namespace MQTTnet.Core.Tests using (var headerStream = new MemoryStream(Join(buffer1))) { - var header = MqttPacketReader.ReadHeaderAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); + var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(headerStream), CancellationToken.None).GetAwaiter().GetResult(); using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength)) { @@ -428,7 +428,7 @@ namespace MQTTnet.Core.Tests } } - private static byte[] Join(IEnumerable> chunks) + private static byte[] Join(params ArraySegment[] chunks) { var buffer = new MemoryStream(); foreach (var chunk in chunks) diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index dc2f525..bcd9fb8 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Diagnostics; @@ -58,8 +59,8 @@ namespace MQTTnet.Core.Tests await s.StartAsync(new MqttServerOptions()); var willMessage = new MqttApplicationMessageBuilder().WithTopic("My/last/will").WithAtMostOnceQoS().Build(); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2", willMessage); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2", willMessage); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic("#").Build()); @@ -90,8 +91,8 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); @@ -149,7 +150,7 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; @@ -167,6 +168,40 @@ namespace MQTTnet.Core.Tests Assert.AreEqual(1, receivedMessagesCount); } + [TestMethod] + public async Task MqttServer_RetainedMessagesFlow() + { + var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); + var serverAdapter = new TestMqttServerAdapter(); + var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); + await s.StartAsync(new MqttServerOptions()); + var c1 = await serverAdapter.ConnectTestClient("c1"); + await c1.PublishAsync(retainedMessage); + Thread.Sleep(500); + await c1.DisconnectAsync(); + Thread.Sleep(500); + + var receivedMessages = 0; + var c2 = await serverAdapter.ConnectTestClient("c2"); + c2.ApplicationMessageReceived += (_, e) => + { + receivedMessages++; + }; + + for (var i = 0; i < 5; i++) + { + await c2.UnsubscribeAsync("r"); + await Task.Delay(500); + Assert.AreEqual(i, receivedMessages); + + await c2.SubscribeAsync("r"); + await Task.Delay(500); + Assert.AreEqual(i + 1, receivedMessages); + } + + await c2.DisconnectAsync(); + } + [TestMethod] public async Task MqttServer_NoRetainedMessage() { @@ -179,11 +214,11 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).Build()); await c1.DisconnectAsync(); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); @@ -208,11 +243,11 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.DisconnectAsync(); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); @@ -237,15 +272,16 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[0]).WithRetainFlag().Build()); await c1.DisconnectAsync(); - - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - await c2.SubscribeAsync(new TopicFilter("retained", MqttQualityOfServiceLevel.AtMostOnce)); + await Task.Delay(200); + await c2.SubscribeAsync(new TopicFilter("retained", MqttQualityOfServiceLevel.AtMostOnce)); await Task.Delay(500); } finally @@ -270,7 +306,7 @@ namespace MQTTnet.Core.Tests await s.StartAsync(options); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); + var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.DisconnectAsync(); @@ -290,7 +326,7 @@ namespace MQTTnet.Core.Tests var options = new MqttServerOptions { Storage = storage }; await s.StartAsync(options); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build()); @@ -321,8 +357,8 @@ namespace MQTTnet.Core.Tests await s.StartAsync(options); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("test").Build()); var isIntercepted = false; @@ -356,8 +392,8 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c1.ApplicationMessageReceived += (_, e) => { @@ -411,8 +447,8 @@ namespace MQTTnet.Core.Tests { await s.StartAsync(new MqttServerOptions()); - var c1 = await serverAdapter.ConnectTestClient(s, "c1"); - var c2 = await serverAdapter.ConnectTestClient(s, "c2"); + var c1 = await serverAdapter.ConnectTestClient("c1"); + var c2 = await serverAdapter.ConnectTestClient("c2"); c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index 8b2863a..6f16ab5 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -1,4 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Adapter; +using MQTTnet.Diagnostics; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Server; @@ -11,12 +13,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); + var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServer(new IMqttServerAdapter[0], new MqttNetLogger())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -32,12 +34,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); + var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServer(new IMqttServerAdapter[0], new MqttNetLogger())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -53,13 +55,13 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); + var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServer(new IMqttServerAdapter[0], new MqttNetLogger())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("#", MqttQualityOfServiceLevel.AtMostOnce)); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtLeastOnce)); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -75,12 +77,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); + var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServer(new IMqttServerAdapter[0], new MqttNetLogger())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -94,12 +96,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); + var sm = new MqttClientSubscriptionsManager("", new MqttServerOptions(), new MqttServer(new IMqttServerAdapter[0], new MqttNetLogger())); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).Wait(); + sm.Subscribe(sp); var pp = new MqttApplicationMessage { @@ -111,7 +113,7 @@ namespace MQTTnet.Core.Tests var up = new MqttUnsubscribePacket(); up.TopicFilters.Add("A/B/C"); - sm.UnsubscribeAsync(up).Wait(); + sm.Unsubscribe(up); Assert.IsFalse(sm.CheckSubscriptionsAsync(pp).Result.IsSubscribed); } diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index a898f76..837221b 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -21,17 +21,17 @@ namespace MQTTnet.Core.Tests { } - public Task ConnectAsync(TimeSpan timeout) + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { return Task.FromResult(0); } - public Task DisconnectAsync(TimeSpan timeout) + public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { return Task.FromResult(0); } - public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) + public Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); @@ -43,11 +43,30 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) + public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); - return Task.Run(() => + if (timeout > TimeSpan.Zero) + { + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) + { + return await Task.Run(() => + { + try + { + return _incomingPackets.Take(cts.Token); + } + catch + { + return null; + } + }, cts.Token); + } + } + + return await Task.Run(() => { try { diff --git a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs index a2fec80..2b24f42 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Client; @@ -11,7 +12,7 @@ namespace MQTTnet.Core.Tests { public event EventHandler ClientAccepted; - public async Task ConnectTestClient(IMqttServer server, string clientId, MqttApplicationMessage willMessage = null) + public async Task ConnectTestClient(string clientId, MqttApplicationMessage willMessage = null) { var adapterA = new TestMqttCommunicationAdapter(); var adapterB = new TestMqttCommunicationAdapter(); @@ -22,8 +23,6 @@ namespace MQTTnet.Core.Tests new TestMqttCommunicationAdapterFactory(adapterA), new MqttNetLogger()); - var connected = WaitForClientToConnect(server, clientId); - FireClientAcceptedEvent(adapterB); var options = new MqttClientOptions @@ -34,29 +33,11 @@ namespace MQTTnet.Core.Tests }; await client.ConnectAsync(options); - await connected; + SpinWait.SpinUntil(() => client.IsConnected); return client; } - private static Task WaitForClientToConnect(IMqttServer s, string clientId) - { - var tcs = new TaskCompletionSource(); - - void Handler(object sender, Server.MqttClientConnectedEventArgs args) - { - if (args.Client.ClientId == clientId) - { - s.ClientConnected -= Handler; - tcs.SetResult(null); - } - } - - s.ClientConnected += Handler; - - return tcs.Task; - } - private void FireClientAcceptedEvent(IMqttChannelAdapter adapter) { ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(adapter)); @@ -71,5 +52,9 @@ namespace MQTTnet.Core.Tests { return Task.FromResult(0); } + + public void Dispose() + { + } } } \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/TopicFilterComparerTests.cs b/Tests/MQTTnet.Core.Tests/TopicFilterComparerTests.cs index be16921..e91e988 100644 --- a/Tests/MQTTnet.Core.Tests/TopicFilterComparerTests.cs +++ b/Tests/MQTTnet.Core.Tests/TopicFilterComparerTests.cs @@ -60,7 +60,33 @@ namespace MQTTnet.Core.Tests CompareAndAssert("A/B/C/D", "#", true); } - private void CompareAndAssert(string topic, string filter, bool expectedResult) + [TestMethod] + public void TopicFilterComparer_MultiLevel_Sport() + { + // Tests from official MQTT spec (4.7.1.2 Multi-level wildcard) + CompareAndAssert("sport/tennis/player1", "sport/tennis/player1/#", true); + CompareAndAssert("sport/tennis/player1/ranking", "sport/tennis/player1/#", true); + CompareAndAssert("sport/tennis/player1/score/wimbledon", "sport/tennis/player1/#", true); + + CompareAndAssert("sport/tennis/player1", "sport/tennis/+", true); + CompareAndAssert("sport/tennis/player2", "sport/tennis/+", true); + CompareAndAssert("sport/tennis/player1/ranking", "sport/tennis/+", false); + + CompareAndAssert("sport", "sport/#", true); + CompareAndAssert("sport", "sport/+", false); + CompareAndAssert("sport/", "sport/+", true); + } + + [TestMethod] + public void TopicFilterComparer_SingleLevel_Finance() + { + // Tests from official MQTT spec (4.7.1.3 Single level wildcard) + CompareAndAssert("/finance", "+/+", true); + CompareAndAssert("/finance", "/+", true); + CompareAndAssert("/finance", "+", false); + } + + private static void CompareAndAssert(string topic, string filter, bool expectedResult) { Assert.AreEqual(expectedResult, MqttTopicFilterComparer.IsMatch(topic, filter)); } diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj index 0a720ca..29b7253 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj +++ b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj @@ -2,7 +2,7 @@ netcoreapp2.0 - 2.3 + Latest diff --git a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs index 5f2e5ad..878a788 100644 --- a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs @@ -44,7 +44,7 @@ namespace MQTTnet.TestApp.NetCore try { - await client.ConnectAsync(options); + await client.ConnectAsync(options).ConfigureAwait(false); } catch (Exception exception) { @@ -59,7 +59,7 @@ namespace MQTTnet.TestApp.NetCore var sentMessagesCount = 0; while (stopwatch.ElapsedMilliseconds < 1000) { - await client.PublishAsync(messages).ConfigureAwait(false); + client.PublishAsync(messages).GetAwaiter().GetResult(); sentMessagesCount++; } @@ -165,7 +165,7 @@ namespace MQTTnet.TestApp.NetCore Console.WriteLine("Press any key to exit."); Console.ReadLine(); - await mqttServer.StopAsync(); + await mqttServer.StopAsync().ConfigureAwait(false); } catch (Exception e) { diff --git a/Tests/MQTTnet.TestApp.NetCore/Program.cs b/Tests/MQTTnet.TestApp.NetCore/Program.cs index 77b9ea6..4032a17 100644 --- a/Tests/MQTTnet.TestApp.NetCore/Program.cs +++ b/Tests/MQTTnet.TestApp.NetCore/Program.cs @@ -19,6 +19,7 @@ namespace MQTTnet.TestApp.NetCore Console.WriteLine("2 = Start server"); Console.WriteLine("3 = Start performance test"); Console.WriteLine("4 = Start managed client"); + Console.WriteLine("5 = Start public broker test"); var pressedKey = Console.ReadKey(true); if (pressedKey.KeyChar == '1') @@ -37,6 +38,10 @@ namespace MQTTnet.TestApp.NetCore { Task.Run(ManagedClientTest.RunAsync); } + else if (pressedKey.KeyChar == '5') + { + Task.Run(PublicBrokerTest.RunAsync); + } Thread.Sleep(Timeout.Infinite); } diff --git a/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs b/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs new file mode 100644 index 0000000..1265a6a --- /dev/null +++ b/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs @@ -0,0 +1,128 @@ +using MQTTnet.Client; +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Protocol; +using Newtonsoft.Json; + +namespace MQTTnet.TestApp.NetCore +{ + public static class PublicBrokerTest + { + public static async Task RunAsync() + { + //MqttNetGlobalLogger.LogMessagePublished += (s, e) => Console.WriteLine(e.TraceMessage); + + // iot.eclipse.org + await ExecuteTestAsync("iot.eclipse.org TCP", + new MqttClientOptionsBuilder().WithTcpServer("iot.eclipse.org", 1883).Build()); + + await ExecuteTestAsync("iot.eclipse.org WS", + new MqttClientOptionsBuilder().WithWebSocketServer("iot.eclipse.org:80/mqtt").Build()); + + await ExecuteTestAsync("iot.eclipse.org WS TLS", + new MqttClientOptionsBuilder().WithWebSocketServer("iot.eclipse.org:443/mqtt").WithTls().Build()); + + // test.mosquitto.org + await ExecuteTestAsync("test.mosquitto.org TCP", + new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 1883).Build()); + + await ExecuteTestAsync("test.mosquitto.org TCP TLS", + new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883).WithTls().Build()); + + await ExecuteTestAsync("test.mosquitto.org WS", + new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8080/mqtt").Build()); + + await ExecuteTestAsync("test.mosquitto.org WS TLS", + new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8081/mqtt").Build()); + + // broker.hivemq.com + await ExecuteTestAsync("broker.hivemq.com TCP", + new MqttClientOptionsBuilder().WithTcpServer("broker.hivemq.com", 1883).Build()); + + await ExecuteTestAsync("broker.hivemq.com WS", + new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").Build()); + + // mqtt.swifitch.cz + await ExecuteTestAsync("mqtt.swifitch.cz", + new MqttClientOptionsBuilder().WithTcpServer("mqtt.swifitch.cz", 1883).Build()); + + // CloudMQTT + var configFile = Path.Combine("E:\\CloudMqttTestConfig.json"); + if (File.Exists(configFile)) + { + var config = JsonConvert.DeserializeObject(File.ReadAllText(configFile)); + + await ExecuteTestAsync("CloudMQTT TCP", + new MqttClientOptionsBuilder().WithTcpServer(config.Server, config.Port).WithCredentials(config.Username, config.Password).Build()); + + await ExecuteTestAsync("CloudMQTT TCP TLS", + new MqttClientOptionsBuilder().WithTcpServer(config.Server, config.SslPort).WithCredentials(config.Username, config.Password).WithTls().Build()); + + await ExecuteTestAsync("CloudMQTT WS TLS", + new MqttClientOptionsBuilder().WithWebSocketServer(config.Server + ":" + config.SslWebSocketPort + "/mqtt").WithCredentials(config.Username, config.Password).WithTls().Build()); + } + + Write("Finished.", ConsoleColor.White); + Console.ReadLine(); + } + + private static async Task ExecuteTestAsync(string name, IMqttClientOptions options) + { + try + { + Write("Testing '" + name + "'... ", ConsoleColor.Gray); + var factory = new MqttFactory(); + var client = factory.CreateMqttClient(); + var topic = Guid.NewGuid().ToString(); + + MqttApplicationMessage receivedMessage = null; + client.ApplicationMessageReceived += (s, e) => receivedMessage = e.ApplicationMessage; + + await client.ConnectAsync(options); + await client.SubscribeAsync(topic, MqttQualityOfServiceLevel.AtLeastOnce); + await client.PublishAsync(topic, "Hello_World", MqttQualityOfServiceLevel.AtLeastOnce); + + SpinWait.SpinUntil(() => receivedMessage != null, 5000); + + if (receivedMessage?.Topic != topic || receivedMessage?.ConvertPayloadToString() != "Hello_World") + { + throw new Exception("Message invalid."); + } + + await client.UnsubscribeAsync("test"); + await client.DisconnectAsync(); + + Write("[OK]\n", ConsoleColor.Green); + } + catch (Exception e) + { + Write("[FAILED] " + e.Message + "\n", ConsoleColor.Red); + } + } + + private static void Write(string message, ConsoleColor color) + { + Console.ForegroundColor = color; + Console.Write(message); + } + + public class MqttConfig + { + public string Server { get; set; } + + public string Username { get; set; } + + public string Password { get; set; } + + public int Port { get; set; } + + public int SslPort { get; set; } + + public int WebSocketPort { get; set; } + + public int SslWebSocketPort { get; set; } + } + } +} diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs index da0c2f3..3de538a 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs @@ -350,13 +350,10 @@ namespace MQTTnet.TestApp.UniversalWindows payload = Convert.FromBase64String(RpcPayload.Text); } - try { var rpcClient = new MqttRpcClient(_mqttClient); - await rpcClient.EnableAsync(); var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), RpcMethod.Text, payload, qos); - await rpcClient.DisableAsync(); RpcResponses.Items.Add(RpcMethod.Text + " >>> " + Encoding.UTF8.GetString(response)); } @@ -364,6 +361,10 @@ namespace MQTTnet.TestApp.UniversalWindows { RpcResponses.Items.Add(RpcMethod.Text + " >>> [TIMEOUT]"); } + catch (Exception exception) + { + RpcResponses.Items.Add(RpcMethod.Text + " >>> [EXCEPTION (" + exception.Message + ")]"); + } } private void ClearRpcResponses(object sender, RoutedEventArgs e)