From caea7910b4b11b36486836fdc6f31a6cb96cd71b Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sat, 16 Jun 2018 14:30:02 +0200 Subject: [PATCH 01/18] Refactor message processing and async/await usage. --- Build/MQTTnet.Extensions.ManagedClient.nuspec | 4 +- Build/MQTTnet.Extensions.Rpc.nuspec | 4 +- Build/MQTTnet.nuspec | 4 +- Source/MQTTnet/Adapter/IMqttChannelAdapter.cs | 2 +- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 160 +++++++---- Source/MQTTnet/Client/MqttClient.cs | 38 ++- .../TargetFrameworkInfoProvider.cs | 2 + .../MQTTnet/Implementations/MqttTcpChannel.cs | 2 +- .../Implementations/MqttTcpServerAdapter.cs | 2 +- .../Implementations/MqttTcpServerListener.cs | 7 +- Source/MQTTnet/Internal/TaskExtensions.cs | 4 +- Source/MQTTnet/MQTTnet.csproj | 5 +- Source/MQTTnet/Serializer/ByteReader.cs | 48 ---- Source/MQTTnet/Serializer/ByteWriter.cs | 36 --- Source/MQTTnet/Serializer/MqttPacketReader.cs | 12 +- .../Serializer/MqttPacketSerializer.cs | 222 ++++++++------- Source/MQTTnet/Serializer/MqttPacketWriter.cs | 128 +++++++-- ...ue.cs => MqttClientPendingPacketsQueue.cs} | 8 +- Source/MQTTnet/Server/MqttClientSession.cs | 41 ++- .../Server/MqttClientSessionsManager.cs | 262 ++++++++++-------- .../Server/MqttEnqueuedApplicationMessage.cs | 15 + Source/MQTTnet/Server/MqttServer.cs | 11 +- ...esult.cs => PrepareClientSessionResult.cs} | 2 +- .../ChannelAdapterBenchmark.cs | 2 +- Tests/MQTTnet.Core.Tests/ByteReaderTests.cs | 30 -- Tests/MQTTnet.Core.Tests/ByteWriterTests.cs | 51 ---- Tests/MQTTnet.Core.Tests/ExtensionTests.cs | 12 +- .../MqttPacketReaderTests.cs | 2 +- .../TestMqttCommunicationAdapter.cs | 3 +- .../PerformanceTest.cs | 98 +++---- Tests/MQTTnet.TestApp.NetCore/Program.cs | 3 +- 31 files changed, 586 insertions(+), 634 deletions(-) delete mode 100644 Source/MQTTnet/Serializer/ByteReader.cs delete mode 100644 Source/MQTTnet/Serializer/ByteWriter.cs rename Source/MQTTnet/Server/{MqttClientPendingMessagesQueue.cs => MqttClientPendingPacketsQueue.cs} (93%) create mode 100644 Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs rename Source/MQTTnet/Server/{GetOrCreateClientSessionResult.cs => PrepareClientSessionResult.cs} (76%) delete mode 100644 Tests/MQTTnet.Core.Tests/ByteReaderTests.cs delete mode 100644 Tests/MQTTnet.Core.Tests/ByteWriterTests.cs diff --git a/Build/MQTTnet.Extensions.ManagedClient.nuspec b/Build/MQTTnet.Extensions.ManagedClient.nuspec index e681554..714ed9e 100644 --- a/Build/MQTTnet.Extensions.ManagedClient.nuspec +++ b/Build/MQTTnet.Extensions.ManagedClient.nuspec @@ -48,6 +48,8 @@ - + + + \ No newline at end of file diff --git a/Build/MQTTnet.Extensions.Rpc.nuspec b/Build/MQTTnet.Extensions.Rpc.nuspec index f8667c4..16a51c2 100644 --- a/Build/MQTTnet.Extensions.Rpc.nuspec +++ b/Build/MQTTnet.Extensions.Rpc.nuspec @@ -48,6 +48,8 @@ - + + + \ No newline at end of file diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 25c31fb..5d58437 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -69,6 +69,8 @@ - + + + \ No newline at end of file diff --git a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs index e026e7b..43e5c8c 100644 --- a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs @@ -20,7 +20,7 @@ namespace MQTTnet.Adapter Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken); - Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken); + Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken); Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); } diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 4fc681a..7db9a2a 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -40,52 +40,84 @@ namespace MQTTnet.Adapter public event EventHandler ReadingPacketStarted; public event EventHandler ReadingPacketCompleted; - public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + public async Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); - _logger.Verbose("Connecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => - Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); + try + { + _logger.Verbose("Connecting [Timeout={0}]", timeout); + + await Internal.TaskExtensions + .TimeoutAfterAsync(ct => _channel.ConnectAsync(ct), timeout, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } } - public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + public async Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); - _logger.Verbose("Disconnecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => - Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, cancellationToken)); + try + { + _logger.Verbose("Disconnecting [Timeout={0}]", timeout); + + await Internal.TaskExtensions + .TimeoutAfterAsync(ct => _channel.DisconnectAsync(), timeout, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } } - public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken) + public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - return ExecuteAndWrapExceptionAsync(() => + try { - _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); + _logger.Verbose("TX >>> {0}", packet); var packetData = PacketSerializer.Serialize(packet); - return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( - packetData.Array, - packetData.Offset, - packetData.Count, - ct), timeout, cancellationToken); - }); + await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } } public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); - MqttBasePacket packet = null; - await ExecuteAndWrapExceptionAsync(async () => + try { ReceivedMqttPacket receivedMqttPacket; if (timeout > TimeSpan.Zero) { - receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); + receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfterAsync(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); } else { @@ -94,19 +126,30 @@ namespace MQTTnet.Adapter if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) { - return; + return null; } - packet = PacketSerializer.Deserialize(receivedMqttPacket); + var packet = PacketSerializer.Deserialize(receivedMqttPacket); if (packet == null) { throw new MqttProtocolViolationException("Received malformed packet."); } _logger.Verbose("RX <<< {0}", packet); - }).ConfigureAwait(false); + + return packet; + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } - return packet; + return null; } private async Task ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) @@ -138,7 +181,9 @@ namespace MQTTnet.Adapter chunkSize = bytesLeft; } - var readBytes = await channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken) .ConfigureAwait(false); + // async/await is not used to avoid the overhead of context switches. We assume that the reamining data + // has been sent from the sender directly after the initial bytes. + var readBytes = channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken).GetAwaiter().GetResult(); if (readBytes <= 0) { ExceptionHelper.ThrowGracefulSocketClose(); @@ -155,42 +200,6 @@ namespace MQTTnet.Adapter } } - private static async Task ExecuteAndWrapExceptionAsync(Func action) - { - try - { - await action().ConfigureAwait(false); - } - catch (Exception exception) - { - if (exception is TaskCanceledException || - exception is OperationCanceledException || - exception is MqttCommunicationTimedOutException || - exception is MqttCommunicationException) - { - throw; - } - - if (exception is IOException && exception.InnerException is SocketException socketException) - { - if (socketException.SocketErrorCode == SocketError.ConnectionAborted) - { - throw new OperationCanceledException(); - } - } - - if (exception is COMException comException) - { - if ((uint)comException.HResult == ErrorOperationAborted) - { - throw new OperationCanceledException(); - } - } - - throw new MqttCommunicationException(exception); - } - } - public void Dispose() { _isDisposed = true; @@ -205,5 +214,34 @@ namespace MQTTnet.Adapter throw new ObjectDisposedException(nameof(MqttChannelAdapter)); } } + + private static bool IsWrappedException(Exception exception) + { + return exception is TaskCanceledException || + exception is OperationCanceledException || + exception is MqttCommunicationTimedOutException || + exception is MqttCommunicationException; + } + + private static void WrapException(Exception exception) + { + if (exception is IOException && exception.InnerException is SocketException socketException) + { + if (socketException.SocketErrorCode == SocketError.ConnectionAborted) + { + throw new OperationCanceledException(); + } + } + + if (exception is COMException comException) + { + if ((uint)comException.HResult == ErrorOperationAborted) + { + throw new OperationCanceledException(); + } + } + + throw new MqttCommunicationException(exception); + } } } diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 79a0ecf..e9362f6 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -17,7 +17,7 @@ namespace MQTTnet.Client { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly Stopwatch _sendTracker = new Stopwatch(); - private readonly SemaphoreSlim _disconnectLock = new SemaphoreSlim(1, 1); + private readonly object _disconnectLock = new object(); private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly IMqttClientAdapterFactory _adapterFactory; @@ -215,7 +215,7 @@ namespace MQTTnet.Client private async Task DisconnectInternalAsync(Task sender, Exception exception) { - await InitiateDisconnectAsync().ConfigureAwait(false); + InitiateDisconnect(); var clientWasConnected = IsConnected; IsConnected = false; @@ -249,25 +249,23 @@ namespace MQTTnet.Client } } - private async Task InitiateDisconnectAsync() + private void InitiateDisconnect() { - await _disconnectLock.WaitAsync().ConfigureAwait(false); - try + lock (_disconnectLock) { - if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested) + try { - return; - } + if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested) + { + return; + } - _cancellationTokenSource.Cancel(false); - } - catch (Exception adapterException) - { - _logger.Warning(adapterException, "Error while initiating disconnect."); - } - finally - { - _disconnectLock.Release(); + _cancellationTokenSource.Cancel(false); + } + catch (Exception adapterException) + { + _logger.Warning(adapterException, "Error while initiating disconnect."); + } } } @@ -279,7 +277,7 @@ namespace MQTTnet.Client } _sendTracker.Restart(); - return _adapter.SendPacketAsync(_options.CommunicationTimeout, packet, cancellationToken); + return _adapter.SendPacketAsync(packet, cancellationToken); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket @@ -300,8 +298,8 @@ namespace MQTTnet.Client var packetAwaiter = _packetDispatcher.AddPacketAwaiter(identifier); try { - await _adapter.SendPacketAsync(_options.CommunicationTimeout, requestPacket, cancellationToken).ConfigureAwait(false); - var respone = await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); + await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); + var respone = await Internal.TaskExtensions.TimeoutAfterAsync(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); return (TResponsePacket)respone; } diff --git a/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs b/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs index 5258276..efbf08b 100644 --- a/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs +++ b/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs @@ -10,6 +10,8 @@ return "net452"; #elif NET461 return "net461"; +#elif NET472 + return "net472"; #elif NETSTANDARD1_3 return "netstandard1.3"; #elif NETSTANDARD2_0 diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 20829be..c48a06f 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -1,4 +1,4 @@ -#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 +#if !WINDOWS_UWP using System; using System.Net.Security; using System.Net.Sockets; diff --git a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs index 8783e9c..eeefb34 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs @@ -1,4 +1,4 @@ -#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 +#if !WINDOWS_UWP using System; using System.Collections.Generic; using System.Net.Sockets; diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index fccd77f..b77d9b3 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -1,4 +1,4 @@ -#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 +#if !WINDOWS_UWP using System; using System.Net; using System.Net.Security; @@ -76,7 +76,8 @@ namespace MQTTnet.Implementations await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false); } - _logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {_addressFamily}'."); + var protocol = _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6"; + _logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {protocol}'."); var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); @@ -102,7 +103,7 @@ namespace MQTTnet.Implementations { _socket?.Dispose(); -#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 +#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 || NET472 _tlsCertificate?.Dispose(); #endif } diff --git a/Source/MQTTnet/Internal/TaskExtensions.cs b/Source/MQTTnet/Internal/TaskExtensions.cs index 288ac0b..1356d97 100644 --- a/Source/MQTTnet/Internal/TaskExtensions.cs +++ b/Source/MQTTnet/Internal/TaskExtensions.cs @@ -7,7 +7,7 @@ namespace MQTTnet.Internal { public static class TaskExtensions { - public static async Task TimeoutAfter(Func action, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task TimeoutAfterAsync(Func action, TimeSpan timeout, CancellationToken cancellationToken) { if (action == null) throw new ArgumentNullException(nameof(action)); @@ -31,7 +31,7 @@ namespace MQTTnet.Internal } } - public static async Task TimeoutAfter(Func> action, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task TimeoutAfterAsync(Func> action, TimeSpan timeout, CancellationToken cancellationToken) { if (action == null) throw new ArgumentNullException(nameof(action)); diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index 8364378..4afd09d 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -1,7 +1,7 @@  - netstandard1.3;netstandard2.0;net452;net461;uap10.0 + netstandard1.3;netstandard2.0;net452;uap10.0 netstandard1.3;netstandard2.0 MQTTnet MQTTnet @@ -62,7 +62,4 @@ - - - \ No newline at end of file diff --git a/Source/MQTTnet/Serializer/ByteReader.cs b/Source/MQTTnet/Serializer/ByteReader.cs deleted file mode 100644 index 461bba3..0000000 --- a/Source/MQTTnet/Serializer/ByteReader.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; - -namespace MQTTnet.Serializer -{ - public class ByteReader - { - private readonly int _source; - private int _index; - - public ByteReader(int source) - { - _source = source; - } - - public bool Read() - { - if (_index >= 8) - { - throw new InvalidOperationException("End of byte reached."); - } - - var result = ((1 << _index) & _source) > 0; - _index++; - return result; - } - - public int Read(int count) - { - if (_index + count > 8) - { - throw new InvalidOperationException("End of byte will be reached."); - } - - var result = 0; - for (var i = 0; i < count; i++) - { - if (((1 << _index) & _source) > 0) - { - result |= 1 << i; - } - - _index++; - } - - return result; - } - } -} diff --git a/Source/MQTTnet/Serializer/ByteWriter.cs b/Source/MQTTnet/Serializer/ByteWriter.cs deleted file mode 100644 index 9ae2156..0000000 --- a/Source/MQTTnet/Serializer/ByteWriter.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System; - -namespace MQTTnet.Serializer -{ - public class ByteWriter - { - private int _index; - private int _byte; - - public byte Value => (byte)_byte; - - public void Write(int @byte, int count) - { - for (var i = 0; i < count; i++) - { - var value = ((1 << i) & @byte) > 0; - Write(value); - } - } - - public void Write(bool bit) - { - if (_index >= 8) - { - throw new InvalidOperationException("End of the byte reached."); - } - - if (bit) - { - _byte |= 1 << _index; - } - - _index++; - } - } -} diff --git a/Source/MQTTnet/Serializer/MqttPacketReader.cs b/Source/MQTTnet/Serializer/MqttPacketReader.cs index 7ed918f..d50fae5 100644 --- a/Source/MQTTnet/Serializer/MqttPacketReader.cs +++ b/Source/MQTTnet/Serializer/MqttPacketReader.cs @@ -12,6 +12,8 @@ namespace MQTTnet.Serializer { // The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. // So in all cases at least 2 bytes must be read for a complete MQTT packet. + // async/await is used here because the next packet is received in a couple of minutes so the performance + // impact is acceptable according to a useless waiting thread. var buffer = new byte[2]; var totalBytesRead = 0; @@ -37,11 +39,11 @@ namespace MQTTnet.Serializer return new MqttFixedHeader(buffer[0], 0); } - var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken).ConfigureAwait(false); + var bodyLength = ReadBodyLength(channel, buffer[1], cancellationToken); return new MqttFixedHeader(buffer[0], bodyLength); } - private static async Task ReadBodyLengthAsync(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) + private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 128; @@ -50,7 +52,11 @@ namespace MQTTnet.Serializer while ((encodedByte & 128) != 0) { - encodedByte = await ReadByteAsync(channel, cancellationToken).ConfigureAwait(false); + // Here the async/await pattern is not used becuase the overhead of context switches + // is too big for reading 1 byte in a row. We expect that the remaining data was sent + // directly after the initial bytes. If the client disconnects just in this moment we + // will get an exception anyway. + encodedByte = ReadByteAsync(channel, cancellationToken).GetAwaiter().GetResult(); value += (byte)(encodedByte & 127) * multiplier; if (multiplier > 128 * 128 * 128) diff --git a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs index 2acbfbb..811acf7 100644 --- a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs +++ b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs @@ -2,7 +2,6 @@ using MQTTnet.Packets; using MQTTnet.Protocol; using System; -using System.IO; using System.Linq; using MQTTnet.Adapter; @@ -18,57 +17,46 @@ namespace MQTTnet.Serializer { if (packet == null) throw new ArgumentNullException(nameof(packet)); - using (var stream = new MemoryStream(128)) - { - // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) - stream.Seek(5, SeekOrigin.Begin); + var packetWriter = new MqttPacketWriter(); - var fixedHeader = SerializePacket(packet, stream); - var remainingLength = (int)stream.Length - 5; + // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) + packetWriter.Seek(5); - var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); + var fixedHeader = SerializePacket(packet, packetWriter); + var remainingLength = packetWriter.Length - 5; - var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; - var headerOffset = 5 - headerSize; + var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); - // Position cursor on correct offset on beginining of array (has leading 0x0) - stream.Seek(headerOffset, SeekOrigin.Begin); - stream.WriteByte(fixedHeader); - stream.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); + var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; + var headerOffset = 5 - headerSize; -#if NET461 || NET452 || NETSTANDARD2_0 - var buffer = stream.GetBuffer(); - return new ArraySegment(buffer, headerOffset, (int)stream.Length - headerOffset); -#else - if (stream.TryGetBuffer(out var segment)) - { - return new ArraySegment(segment.Array, headerOffset, segment.Count - headerOffset); - } + // Position cursor on correct offset on beginining of array (has leading 0x0) + packetWriter.Seek(headerOffset); + packetWriter.Write(fixedHeader); + packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); - var buffer = stream.ToArray(); - return new ArraySegment(buffer, headerOffset, buffer.Length - headerOffset); -#endif - } + var buffer = packetWriter.GetBuffer(); + return new ArraySegment(buffer, headerOffset, packetWriter.Length - headerOffset); } - private byte SerializePacket(MqttBasePacket packet, Stream stream) + private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) { switch (packet) { - case MqttConnectPacket connectPacket: return Serialize(connectPacket, stream); - case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, stream); + case MqttConnectPacket connectPacket: return Serialize(connectPacket, packetWriter); + case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, packetWriter); case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect); case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq); case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp); - case MqttPublishPacket publishPacket: return Serialize(publishPacket, stream); - case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, stream); - case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, stream); - case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, stream); - case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, stream); - case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, stream); - case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, stream); - case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, stream); - case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, stream); + case MqttPublishPacket publishPacket: return Serialize(publishPacket, packetWriter); + case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, packetWriter); + case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, packetWriter); + case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, packetWriter); + case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, packetWriter); + case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, packetWriter); + case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, packetWriter); + case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, packetWriter); + case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, packetWriter); default: throw new MqttProtocolViolationException("Packet type invalid."); } } @@ -195,10 +183,9 @@ namespace MQTTnet.Serializer var body = receivedMqttPacket.Body; ThrowIfBodyIsEmpty(body); - var fixedHeader = new ByteReader(receivedMqttPacket.FixedHeader); - var retain = fixedHeader.Read(); - var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2); - var dup = fixedHeader.Read(); + var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; + var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); + var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0; var topic = body.ReadStringWithLengthPrefix(); @@ -253,8 +240,8 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException($"Protocol name ({protocolName}) is not supported."); } - var connectFlags = new ByteReader(body.ReadByte()); - if (connectFlags.Read()) + var connectFlags = body.ReadByte(); + if ((connectFlags & 0x1) > 0) { throw new MqttProtocolViolationException("The first bit of the Connect Flags must be set to 0."); } @@ -262,14 +249,14 @@ namespace MQTTnet.Serializer var packet = new MqttConnectPacket { ProtocolVersion = protocolVersion, - CleanSession = connectFlags.Read() + CleanSession = (connectFlags & 0x2) > 0 }; - var willFlag = connectFlags.Read(); - var willQoS = connectFlags.Read(2); - var willRetain = connectFlags.Read(); - var passwordFlag = connectFlags.Read(); - var usernameFlag = connectFlags.Read(); + var willFlag = (connectFlags & 0x4) > 0; + var willQoS = (connectFlags & 0x18) >> 3; + var willRetain = (connectFlags & 0x20) > 0; + var passwordFlag = (connectFlags & 0x40) > 0; + var usernameFlag = (connectFlags & 0x80) > 0; packet.KeepAlivePeriod = body.ReadUInt16(); packet.ClientId = body.ReadStringWithLengthPrefix(); @@ -322,11 +309,11 @@ namespace MQTTnet.Serializer var packet = new MqttConnAckPacket(); - var firstByteReader = new ByteReader(body.ReadByte()); - + var acknowledgeFlags = body.ReadByte(); + if (ProtocolVersion == MqttProtocolVersion.V311) { - packet.IsSessionPresent = firstByteReader.Read(); + packet.IsSessionPresent = (acknowledgeFlags & 0x1) > 0; } packet.ConnectReturnCode = (MqttConnectReturnCode)body.ReadByte(); @@ -344,119 +331,129 @@ namespace MQTTnet.Serializer } } + // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local private static void ValidatePublishPacket(MqttPublishPacket packet) { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - if (packet.QualityOfServiceLevel == 0 && packet.Dup) { throw new MqttProtocolViolationException("Dup flag must be false for QoS 0 packets [MQTT-3.3.1-2]."); } } - private byte Serialize(MqttConnectPacket packet, Stream stream) + private byte Serialize(MqttConnectPacket packet, MqttPacketWriter packetWriter) { ValidateConnectPacket(packet); // Write variable header if (ProtocolVersion == MqttProtocolVersion.V311) { - stream.WriteWithLengthPrefix("MQTT"); - stream.WriteByte(4); // 3.1.2.2 Protocol Level 4 + packetWriter.WriteWithLengthPrefix("MQTT"); + packetWriter.Write(4); // 3.1.2.2 Protocol Level 4 } else { - stream.WriteWithLengthPrefix("MQIsdp"); - stream.WriteByte(3); // Protocol Level 3 + packetWriter.WriteWithLengthPrefix("MQIsdp"); + packetWriter.Write(3); // Protocol Level 3 } - var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags - connectFlags.Write(false); // Reserved - connectFlags.Write(packet.CleanSession); - connectFlags.Write(packet.WillMessage != null); - - if (packet.WillMessage != null) + byte connectFlags = 0x0; + if (packet.CleanSession) { - connectFlags.Write((int)packet.WillMessage.QualityOfServiceLevel, 2); - connectFlags.Write(packet.WillMessage.Retain); + connectFlags |= 0x2; } - else + + if (packet.WillMessage != null) { - connectFlags.Write(0, 2); - connectFlags.Write(false); - } + connectFlags |= 0x4; + connectFlags |= (byte)((byte)packet.WillMessage.QualityOfServiceLevel << 3); + if (packet.WillMessage.Retain) + { + connectFlags |= 0x20; + } + } + if (packet.Password != null && packet.Username == null) { throw new MqttProtocolViolationException("If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22]."); } - connectFlags.Write(packet.Password != null); - connectFlags.Write(packet.Username != null); + if (packet.Password != null) + { + connectFlags |= 0x40; + } - stream.Write(connectFlags); - stream.Write(packet.KeepAlivePeriod); - stream.WriteWithLengthPrefix(packet.ClientId); + if (packet.Username != null) + { + connectFlags |= 0x80; + } + + packetWriter.Write(connectFlags); + packetWriter.Write(packet.KeepAlivePeriod); + packetWriter.WriteWithLengthPrefix(packet.ClientId); if (packet.WillMessage != null) { - stream.WriteWithLengthPrefix(packet.WillMessage.Topic); - stream.WriteWithLengthPrefix(packet.WillMessage.Payload); + packetWriter.WriteWithLengthPrefix(packet.WillMessage.Topic); + packetWriter.WriteWithLengthPrefix(packet.WillMessage.Payload); } if (packet.Username != null) { - stream.WriteWithLengthPrefix(packet.Username); + packetWriter.WriteWithLengthPrefix(packet.Username); } if (packet.Password != null) { - stream.WriteWithLengthPrefix(packet.Password); + packetWriter.WriteWithLengthPrefix(packet.Password); } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - private byte Serialize(MqttConnAckPacket packet, Stream stream) + private byte Serialize(MqttConnAckPacket packet, MqttPacketWriter packetWriter) { if (ProtocolVersion == MqttProtocolVersion.V310) { - stream.WriteByte(0); + packetWriter.Write(0); } else if (ProtocolVersion == MqttProtocolVersion.V311) { - var connectAcknowledgeFlags = new ByteWriter(); - connectAcknowledgeFlags.Write(packet.IsSessionPresent); - - stream.Write(connectAcknowledgeFlags); + byte connectAcknowledgeFlags = 0x0; + if (packet.IsSessionPresent) + { + connectAcknowledgeFlags |= 0x1; + } + + packetWriter.Write(connectAcknowledgeFlags); } else { throw new MqttProtocolViolationException("Protocol version not supported."); } - stream.WriteByte((byte)packet.ConnectReturnCode); + packetWriter.Write((byte)packet.ConnectReturnCode); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static byte Serialize(MqttPubRelPacket packet, Stream stream) + private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubRel packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static byte Serialize(MqttPublishPacket packet, Stream stream) + private static byte Serialize(MqttPublishPacket packet, MqttPacketWriter packetWriter) { ValidatePublishPacket(packet); - stream.WriteWithLengthPrefix(packet.Topic); + packetWriter.WriteWithLengthPrefix(packet.Topic); if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { @@ -465,7 +462,7 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Publish packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); } else { @@ -477,7 +474,7 @@ namespace MQTTnet.Serializer if (packet.Payload?.Length > 0) { - stream.Write(packet.Payload, 0, packet.Payload.Length); + packetWriter.Write(packet.Payload, 0, packet.Payload.Length); } byte fixedHeader = 0; @@ -497,43 +494,43 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } - private static byte Serialize(MqttPubAckPacket packet, Stream stream) + private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static byte Serialize(MqttPubRecPacket packet, Stream stream) + private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubRec packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static byte Serialize(MqttPubCompPacket packet, Stream stream) + private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubComp packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static byte Serialize(MqttSubscribePacket packet, Stream stream) + private static byte Serialize(MqttSubscribePacket packet, MqttPacketWriter packetWriter) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); @@ -542,41 +539,41 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Subscribe packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Count > 0) { foreach (var topicFilter in packet.TopicFilters) { - stream.WriteWithLengthPrefix(topicFilter.Topic); - stream.WriteByte((byte)topicFilter.QualityOfServiceLevel); + packetWriter.WriteWithLengthPrefix(topicFilter.Topic); + packetWriter.Write((byte)topicFilter.QualityOfServiceLevel); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static byte Serialize(MqttSubAckPacket packet, Stream stream) + private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("SubAck packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); if (packet.SubscribeReturnCodes?.Any() == true) { foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes) { - stream.WriteByte((byte)packetSubscribeReturnCode); + packetWriter.Write((byte)packetSubscribeReturnCode); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static byte Serialize(MqttUnsubscribePacket packet, Stream stream) + private static byte Serialize(MqttUnsubscribePacket packet, MqttPacketWriter packetWriter) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); @@ -585,27 +582,27 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Any() == true) { foreach (var topicFilter in packet.TopicFilters) { - stream.WriteWithLengthPrefix(topicFilter); + packetWriter.WriteWithLengthPrefix(topicFilter); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte Serialize(MqttUnsubAckPacket packet, Stream stream) + private static byte Serialize(MqttUnsubAckPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } @@ -614,6 +611,7 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(type); } + // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local private static void ThrowIfBodyIsEmpty(MqttPacketBodyReader body) { if (body == null || body.Length == 0) diff --git a/Source/MQTTnet/Serializer/MqttPacketWriter.cs b/Source/MQTTnet/Serializer/MqttPacketWriter.cs index 2cc9d7b..c0c49fc 100644 --- a/Source/MQTTnet/Serializer/MqttPacketWriter.cs +++ b/Source/MQTTnet/Serializer/MqttPacketWriter.cs @@ -1,12 +1,23 @@ using System; -using System.IO; using System.Text; using MQTTnet.Protocol; namespace MQTTnet.Serializer { - public static class MqttPacketWriter + /// + /// This is a custom implementation of a memory stream which provides only MQTTnet relevant features. + /// The goal is to avoid lots of argument checks like in the original stream. The growth rule is the + /// same as for the original MemoryStream in .net. Also this implementation allows accessing the internal + /// buffer for all platforms and .net framework versions (which is not available at the regular MemoryStream). + /// + public class MqttPacketWriter { + private byte[] _buffer = new byte[128]; + + private int _position; + + public int Length { get; private set; } + public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) { var fixedHeader = (int)packetType << 4; @@ -14,33 +25,6 @@ namespace MQTTnet.Serializer return (byte)fixedHeader; } - public static void Write(this Stream stream, ushort value) - { - var buffer = BitConverter.GetBytes(value); - stream.WriteByte(buffer[1]); - stream.WriteByte(buffer[0]); - } - - public static void Write(this Stream stream, ByteWriter value) - { - if (value == null) throw new ArgumentNullException(nameof(value)); - - stream.WriteByte(value.Value); - } - - public static void WriteWithLengthPrefix(this Stream stream, string value) - { - stream.WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); - } - - public static void WriteWithLengthPrefix(this Stream stream, byte[] value) - { - var length = (ushort)value.Length; - - stream.Write(length); - stream.Write(value, 0, length); - } - public static ArraySegment EncodeRemainingLength(int length) { // write the encoded remaining length right aligned on the 4 byte buffer @@ -69,5 +53,91 @@ namespace MQTTnet.Serializer return new ArraySegment(buffer, 0, bufferOffset); } + + public void WriteWithLengthPrefix(string value) + { + WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); + } + + public void WriteWithLengthPrefix(byte[] value) + { + EnsureAdditionalCapacity(value.Length + 2); + + Write((ushort)value.Length); + Write(value, 0, value.Length); + } + + public void Write(byte @byte) + { + EnsureAdditionalCapacity(1); + + _buffer[_position] = @byte; + IncreasePostition(1); + } + + public void Write(ushort value) + { + EnsureAdditionalCapacity(2); + + _buffer[_position] = (byte)(value >> 8); + IncreasePostition(1); + _buffer[_position] = (byte)value; + IncreasePostition(1); + } + + public void Write(byte[] array, int offset, int count) + { + EnsureAdditionalCapacity(count); + + Array.Copy(array, offset, _buffer, _position, count); + IncreasePostition(count); + } + + public void Seek(int offset) + { + EnsureCapacity(offset); + _position = offset; + } + + public byte[] GetBuffer() + { + return _buffer; + } + + private void EnsureAdditionalCapacity(int additionalCapacity) + { + var freeSpace = _buffer.Length - _position; + if (freeSpace >= additionalCapacity) + { + return; + } + + EnsureCapacity(additionalCapacity - freeSpace); + } + + private void EnsureCapacity(int capacity) + { + if (_buffer.Length >= capacity) + { + return; + } + + var newBufferLength = _buffer.Length; + while (newBufferLength < capacity) + { + newBufferLength *= 2; + } + + Array.Resize(ref _buffer, newBufferLength); + } + + private void IncreasePostition(int length) + { + _position += length; + if (_position > Length) + { + Length = _position; + } + } } } diff --git a/Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs similarity index 93% rename from Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs rename to Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs index ff43e6f..7a55c0b 100644 --- a/Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs +++ b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs @@ -11,7 +11,7 @@ using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientPendingMessagesQueue : IDisposable + public class MqttClientPendingPacketsQueue : IDisposable { private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); private readonly IMqttServerOptions _options; @@ -20,13 +20,13 @@ namespace MQTTnet.Server private ConcurrentQueue _queue = new ConcurrentQueue(); - public MqttClientPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) + public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _options = options ?? throw new ArgumentNullException(nameof(options)); _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); - _logger = logger.CreateChildLogger(nameof(MqttClientPendingMessagesQueue)); + _logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue)); } public int Count => _queue.Count; @@ -115,7 +115,7 @@ namespace MQTTnet.Server return; } - await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, packet, cancellationToken).ConfigureAwait(false); + await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); } diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 28e97f0..460bea4 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -18,7 +18,7 @@ namespace MQTTnet.Server private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; - private readonly MqttClientPendingMessagesQueue _pendingMessagesQueue; + private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue; private readonly MqttClientSubscriptionsManager _subscriptionsManager; private readonly MqttClientSessionsManager _sessionsManager; @@ -49,7 +49,7 @@ namespace MQTTnet.Server _keepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, () => Stop(MqttClientDisconnectType.NotClean), _logger); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server); - _pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger); + _pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger); } public string ClientId { get; } @@ -60,7 +60,7 @@ namespace MQTTnet.Server status.IsConnected = _adapter != null; status.Endpoint = _adapter?.Endpoint; status.ProtocolVersion = _adapter?.PacketSerializer?.ProtocolVersion; - status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count; + status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count; status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived; status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; } @@ -80,7 +80,7 @@ namespace MQTTnet.Server _wasCleanDisconnect = false; _willMessage = connectPacket.WillMessage; - _pendingMessagesQueue.Start(adapter, _cancellationTokenSource.Token); + _pendingPacketsQueue.Start(adapter, _cancellationTokenSource.Token); _keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token); while (!_cancellationTokenSource.IsCancellationRequested) @@ -149,13 +149,10 @@ namespace MQTTnet.Server if (_willMessage != null && !_wasCleanDisconnect) { - _sessionsManager.StartDispatchApplicationMessage(this, _willMessage); + _sessionsManager.EnqueueApplicationMessage(this, _willMessage); } _willMessage = null; - - ////_pendingMessagesQueue.WaitForCompletion(); - ////_keepAliveMonitor.WaitForCompletion(); } finally { @@ -196,7 +193,7 @@ namespace MQTTnet.Server } } - _pendingMessagesQueue.Enqueue(publishPacket); + _pendingPacketsQueue.Enqueue(publishPacket); } public Task SubscribeAsync(IList topicFilters) @@ -226,12 +223,12 @@ namespace MQTTnet.Server public void ClearPendingApplicationMessages() { - _pendingMessagesQueue.Clear(); + _pendingPacketsQueue.Clear(); } public void Dispose() { - _pendingMessagesQueue?.Dispose(); + _pendingPacketsQueue?.Dispose(); _cancellationTokenSource?.Dispose(); } @@ -245,7 +242,7 @@ namespace MQTTnet.Server if (packet is MqttPingReqPacket) { - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket(), cancellationToken); + return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); } if (packet is MqttPubRelPacket pubRelPacket) @@ -260,7 +257,7 @@ namespace MQTTnet.Server PacketIdentifier = pubRecPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, responsePacket, cancellationToken); + return adapter.SendPacketAsync(responsePacket, cancellationToken); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -308,7 +305,7 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); - await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); + await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); if (subscribeResult.CloseConnection) { @@ -322,7 +319,7 @@ namespace MQTTnet.Server private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, unsubscribeResult, cancellationToken); + return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); } private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) @@ -333,7 +330,7 @@ namespace MQTTnet.Server { case MqttQualityOfServiceLevel.AtMostOnce: { - _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); return Task.FromResult(0); } case MqttQualityOfServiceLevel.AtLeastOnce: @@ -353,25 +350,25 @@ namespace MQTTnet.Server private Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); + return adapter.SendPacketAsync(response, cancellationToken); } private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) - _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); + return adapter.SendPacketAsync(response, cancellationToken); } - private Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) + private static Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); + return adapter.SendPacketAsync(response, cancellationToken); } private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index c52c706..1cb301a 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -6,27 +6,29 @@ using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientSessionsManager : IDisposable + public class MqttClientSessionsManager { + private readonly BlockingCollection _messageQueue = new BlockingCollection(); private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); - private readonly AsyncLock _sessionPreparationLock = new AsyncLock(); + + private readonly CancellationToken _cancellationToken; private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly IMqttServerOptions _options; private readonly IMqttNetChildLogger _logger; - public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetChildLogger logger) + public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, CancellationToken cancellationToken, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager)); + _cancellationToken = cancellationToken; _options = options ?? throw new ArgumentNullException(nameof(options)); Server = server ?? throw new ArgumentNullException(nameof(server)); _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); @@ -34,7 +36,129 @@ namespace MQTTnet.Server public MqttServer Server { get; } - public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) + public void Start() + { + Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); + } + + public Task StopAsync() + { + foreach (var session in _sessions) + { + session.Value.Stop(MqttClientDisconnectType.NotClean); + } + + _sessions.Clear(); + return Task.FromResult(0); + } + + public Task StartSession(IMqttChannelAdapter clientAdapter) + { + return Task.Run(() => RunSession(clientAdapter, _cancellationToken), _cancellationToken); + } + + public Task> GetClientStatusAsync() + { + var result = new List(); + foreach (var session in _sessions) + { + var status = new MqttClientSessionStatus(this, session.Value); + session.Value.FillStatus(status); + + result.Add(status); + } + + return Task.FromResult((IList)result); + } + + public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + { + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); + } + + public Task SubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); + } + + return session.SubscribeAsync(topicFilters); + } + + public Task UnsubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); + } + + return session.UnsubscribeAsync(topicFilters); + } + + public void DeleteSession(string clientId) + { + _sessions.TryRemove(clientId, out _); + _logger.Verbose("Session for client '{0}' deleted.", clientId); + } + + private void ProcessQueuedApplicationMessages(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + try + { + var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken); + var sender = enqueuedApplicationMessage.Sender; + var applicationMessage = enqueuedApplicationMessage.ApplicationMessage; + + var interceptorContext = InterceptApplicationMessage(sender, applicationMessage); + if (interceptorContext != null) + { + if (interceptorContext.CloseConnection) + { + enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean); + } + + if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) + { + return; + } + + applicationMessage = interceptorContext.ApplicationMessage; + } + + Server.OnApplicationMessageReceived(sender?.ClientId, applicationMessage); + + if (applicationMessage.Retain) + { + _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult(); + } + + foreach (var clientSession in _sessions.Values) + { + clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage); + } + } + catch (TaskCanceledException) + { + } + catch (Exception exception) + { + _logger.Error(exception, "Unhandled exception while processing queued application message."); + } + } + } + + private async Task RunSession(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) { var clientId = string.Empty; var wasCleanDisconnect = false; @@ -60,7 +184,7 @@ namespace MQTTnet.Server var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout, + await clientAdapter.SendPacketAsync( new MqttConnAckPacket { ConnectReturnCode = connectReturnCode @@ -70,15 +194,15 @@ namespace MQTTnet.Server return; } - var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); + var result = PrepareClientSession(connectPacket); var clientSession = result.Session; - await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout, + await clientAdapter.SendPacketAsync( new MqttConnAckPacket { ConnectReturnCode = connectReturnCode, IsSessionPresent = result.IsExistingSession - }, + }, cancellationToken).ConfigureAwait(false); Server.OnClientConnected(clientId); @@ -113,73 +237,6 @@ namespace MQTTnet.Server } } - public Task StopAsync() - { - foreach (var session in _sessions) - { - session.Value.Stop(MqttClientDisconnectType.NotClean); - } - - _sessions.Clear(); - return Task.FromResult(0); - } - - public Task> GetClientStatusAsync() - { - var result = new List(); - foreach (var session in _sessions) - { - var status = new MqttClientSessionStatus(this, session.Value); - session.Value.FillStatus(status); - - result.Add(status); - } - - return Task.FromResult((IList)result); - } - - public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) - { - Task.Run(() => DispatchApplicationMessageAsync(senderClientSession, applicationMessage)); - } - - public Task SubscribeAsync(string clientId, IList topicFilters) - { - if (clientId == null) throw new ArgumentNullException(nameof(clientId)); - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - if (!_sessions.TryGetValue(clientId, out var session)) - { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } - - return session.SubscribeAsync(topicFilters); - } - - public Task UnsubscribeAsync(string clientId, IList topicFilters) - { - if (clientId == null) throw new ArgumentNullException(nameof(clientId)); - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - if (!_sessions.TryGetValue(clientId, out var session)) - { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } - - return session.UnsubscribeAsync(topicFilters); - } - - public void DeleteSession(string clientId) - { - _sessions.TryRemove(clientId, out _); - _logger.Verbose("Session for client '{0}' deleted.", clientId); - } - - public void Dispose() - { - _sessionPreparationLock?.Dispose(); - } - private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) { if (_options.ConnectionValidator == null) @@ -197,9 +254,9 @@ namespace MQTTnet.Server return context.ReturnCode; } - private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) + private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket) { - using (await _sessionPreparationLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + lock (_sessions) { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) @@ -231,60 +288,19 @@ namespace MQTTnet.Server _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } - return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; + return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; } } - private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) { - try - { - var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage); - if (interceptorContext != null) - { - if (interceptorContext.CloseConnection) - { - senderClientSession.Stop(MqttClientDisconnectType.NotClean); - } - - if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) - { - return; - } - - applicationMessage = interceptorContext.ApplicationMessage; - } - - Server.OnApplicationMessageReceived(senderClientSession?.ClientId, applicationMessage); - - if (applicationMessage.Retain) - { - await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false); - } - - foreach (var clientSession in _sessions.Values) - { - clientSession.EnqueueApplicationMessage(senderClientSession, applicationMessage); - } - } - catch (Exception exception) - { - _logger.Error(exception, "Error while processing application message"); - } - } - - private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) - { - var interceptorContext = new MqttApplicationMessageInterceptorContext( - senderClientSession?.ClientId, - applicationMessage); - var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) { - return interceptorContext; + return null; } + var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage); interceptor(interceptorContext); return interceptorContext; } diff --git a/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs new file mode 100644 index 0000000..20ff2fe --- /dev/null +++ b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs @@ -0,0 +1,15 @@ +namespace MQTTnet.Server +{ + public class MqttEnqueuedApplicationMessage + { + public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) + { + Sender = sender; + ApplicationMessage = applicationMessage; + } + + public MqttClientSession Sender { get; } + + public MqttApplicationMessage ApplicationMessage { get; } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index a38dfd2..38e631e 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -65,7 +65,7 @@ namespace MQTTnet.Server if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started."); - _clientSessionsManager.StartDispatchApplicationMessage(null, applicationMessage); + _clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage); return Task.FromResult(0); } @@ -81,7 +81,8 @@ namespace MQTTnet.Server _retainedMessagesManager = new MqttRetainedMessagesManager(Options, _logger); await _retainedMessagesManager.LoadMessagesAsync().ConfigureAwait(false); - _clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _logger); + _clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _cancellationTokenSource.Token, _logger); + _clientSessionsManager.Start(); foreach (var adapter in _adapters) { @@ -118,8 +119,6 @@ namespace MQTTnet.Server } finally { - _clientSessionsManager?.Dispose(); - _cancellationTokenSource = null; _retainedMessagesManager = null; _clientSessionsManager = null; @@ -155,9 +154,7 @@ namespace MQTTnet.Server private void OnClientAccepted(object sender, MqttServerAdapterClientAcceptedEventArgs eventArgs) { - eventArgs.SessionTask = Task.Run( - () => _clientSessionsManager.RunSessionAsync(eventArgs.Client, _cancellationTokenSource.Token), - _cancellationTokenSource.Token); + eventArgs.SessionTask = _clientSessionsManager.StartSession(eventArgs.Client); } } } diff --git a/Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs b/Source/MQTTnet/Server/PrepareClientSessionResult.cs similarity index 76% rename from Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs rename to Source/MQTTnet/Server/PrepareClientSessionResult.cs index 975d237..9a655be 100644 --- a/Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs +++ b/Source/MQTTnet/Server/PrepareClientSessionResult.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Server { - public class GetOrCreateClientSessionResult + public class PrepareClientSessionResult { public bool IsExistingSession { get; set; } diff --git a/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs b/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs index 20b581c..7f068d6 100644 --- a/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs @@ -65,7 +65,7 @@ namespace MQTTnet.Benchmarks for (var i = 0; i < 10000; i++) { - _channelAdapter.SendPacketAsync(TimeSpan.FromSeconds(15), _packet, CancellationToken.None).GetAwaiter().GetResult(); + _channelAdapter.SendPacketAsync(_packet, CancellationToken.None).GetAwaiter().GetResult(); } _stream.Position = 0; diff --git a/Tests/MQTTnet.Core.Tests/ByteReaderTests.cs b/Tests/MQTTnet.Core.Tests/ByteReaderTests.cs deleted file mode 100644 index e2173cb..0000000 --- a/Tests/MQTTnet.Core.Tests/ByteReaderTests.cs +++ /dev/null @@ -1,30 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Serializer; - -namespace MQTTnet.Core.Tests -{ - [TestClass] - public class ByteReaderTests - { - [TestMethod] - public void ByteReader_ReadToEnd() - { - var reader = new ByteReader(85); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - } - - [TestMethod] - public void ByteReader_ReadPartial() - { - var reader = new ByteReader(15); - Assert.AreEqual(3, reader.Read(2)); - } - } -} diff --git a/Tests/MQTTnet.Core.Tests/ByteWriterTests.cs b/Tests/MQTTnet.Core.Tests/ByteWriterTests.cs deleted file mode 100644 index 881df5c..0000000 --- a/Tests/MQTTnet.Core.Tests/ByteWriterTests.cs +++ /dev/null @@ -1,51 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Serializer; - -namespace MQTTnet.Core.Tests -{ - [TestClass] - public class ByteWriterTests - { - [TestMethod] - public void ByteWriter_WriteMultipleAll() - { - var b = new ByteWriter(); - Assert.AreEqual(0, b.Value); - b.Write(3, 2); - Assert.AreEqual(3, b.Value); - } - - [TestMethod] - public void ByteWriter_WriteMultiplePartial() - { - var b = new ByteWriter(); - Assert.AreEqual(0, b.Value); - b.Write(255, 2); - Assert.AreEqual(3, b.Value); - } - - [TestMethod] - public void ByteWriter_WriteTo0xFF() - { - var b = new ByteWriter(); - - Assert.AreEqual(0, b.Value); - b.Write(true); - Assert.AreEqual(1, b.Value); - b.Write(true); - Assert.AreEqual(3, b.Value); - b.Write(true); - Assert.AreEqual(7, b.Value); - b.Write(true); - Assert.AreEqual(15, b.Value); - b.Write(true); - Assert.AreEqual(31, b.Value); - b.Write(true); - Assert.AreEqual(63, b.Value); - b.Write(true); - Assert.AreEqual(127, b.Value); - b.Write(true); - Assert.AreEqual(255, b.Value); - } - } -} diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index 696cfa3..fdb5888 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests [TestMethod] public async Task TimeoutAfter() { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] public async Task TimeoutAfterWithResult() { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [TestMethod] public async Task TimeoutAfterCompleteInTime() { - var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); + var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); Assert.AreEqual(5, result); } @@ -36,7 +36,7 @@ namespace MQTTnet.Core.Tests { try { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -55,7 +55,7 @@ namespace MQTTnet.Core.Tests { try { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -76,7 +76,7 @@ namespace MQTTnet.Core.Tests var tasks = Enumerable.Range(0, 100000) .Select(i => { - return MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); + return MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); }); await Task.WhenAll(tasks); diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index e3e1ccd..f983622 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Core.Tests public class MqttPacketReaderTests { [TestMethod] - [ExpectedException(typeof(MqttCommunicationException))] + [ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] public void MqttPacketReader_EmptyStream() { MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index ffe2a77..9375f07 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -36,7 +35,7 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken) + public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); diff --git a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs index e9e4bc4..752feb5 100644 --- a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs @@ -12,16 +12,46 @@ namespace MQTTnet.TestApp.NetCore { public static class PerformanceTest { - public static async Task RunAsync() + public static void Run() { - Console.WriteLine("Press 'c' for concurrent sends. Otherwise in one batch."); - var concurrent = Console.ReadKey(true).KeyChar == 'c'; + try + { + var mqttServer = new MqttFactory().CreateMqttServer(); + mqttServer.StartAsync(new MqttServerOptions()).GetAwaiter().GetResult(); - var server = Task.Run(RunServerAsync); - await Task.Delay(1000); - var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10), concurrent)); + var options = new MqttClientOptions + { + ChannelOptions = new MqttClientTcpOptions + { + Server = "127.0.0.1" + }, + CleanSession = true + }; + + var client = new MqttFactory().CreateMqttClient(); + client.ConnectAsync(options).GetAwaiter().GetResult(); + + var message = CreateMessage(); + var stopwatch = new Stopwatch(); - await Task.WhenAll(server, client).ConfigureAwait(false); + for (var i = 0; i < 10; i++) + { + stopwatch.Restart(); + + var sentMessagesCount = 0; + while (stopwatch.ElapsedMilliseconds < 1000) + { + client.PublishAsync(message).GetAwaiter().GetResult(); + sentMessagesCount++; + } + + Console.WriteLine($"Sending {sentMessagesCount} messages per second. #" + (i + 1)); + } + } + catch (Exception exception) + { + Console.WriteLine(exception); + } } private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval, bool concurrent) @@ -53,29 +83,8 @@ namespace MQTTnet.TestApp.NetCore } var message = CreateMessage(); - var messages = new[] { message }; - var stopwatch = Stopwatch.StartNew(); - var sentMessagesCount = 0; - while (stopwatch.ElapsedMilliseconds < 1000) - { - client.PublishAsync(messages).GetAwaiter().GetResult(); - sentMessagesCount++; - } - - Console.WriteLine($"Sending {sentMessagesCount} messages per second. #1"); - - sentMessagesCount = 0; - stopwatch.Restart(); - while (stopwatch.ElapsedMilliseconds < 1000) - { - await client.PublishAsync(messages).ConfigureAwait(false); - sentMessagesCount++; - } - - Console.WriteLine($"Sending {sentMessagesCount} messages per second. #2"); - var testMessageCount = 10000; for (var i = 0; i < testMessageCount; i++) { @@ -142,38 +151,5 @@ namespace MQTTnet.TestApp.NetCore Interlocked.Increment(ref count); return Task.Run(() => client.PublishAsync(applicationMessage)); } - - private static async Task RunServerAsync() - { - try - { - var mqttServer = new MqttFactory().CreateMqttServer(); - - ////var msgs = 0; - ////var stopwatch = Stopwatch.StartNew(); - ////mqttServer.ApplicationMessageReceived += (sender, args) => - ////{ - //// msgs++; - //// if (stopwatch.ElapsedMilliseconds > 1000) - //// { - //// Console.WriteLine($"received {msgs}"); - //// msgs = 0; - //// stopwatch.Restart(); - //// } - ////}; - await mqttServer.StartAsync(new MqttServerOptions()); - - Console.WriteLine("Press any key to exit."); - Console.ReadLine(); - - await mqttServer.StopAsync().ConfigureAwait(false); - } - catch (Exception e) - { - Console.WriteLine(e); - } - - Console.ReadLine(); - } } } diff --git a/Tests/MQTTnet.TestApp.NetCore/Program.cs b/Tests/MQTTnet.TestApp.NetCore/Program.cs index f8a5d27..25302c7 100644 --- a/Tests/MQTTnet.TestApp.NetCore/Program.cs +++ b/Tests/MQTTnet.TestApp.NetCore/Program.cs @@ -34,7 +34,8 @@ namespace MQTTnet.TestApp.NetCore } else if (pressedKey.KeyChar == '3') { - Task.Run(PerformanceTest.RunAsync); + PerformanceTest.Run(); + return; } else if (pressedKey.KeyChar == '4') { From c758ae89f4bdff2a8b6b91412e7d5cbd80b56755 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sun, 27 May 2018 08:24:15 +0200 Subject: [PATCH 02/18] first iteration --- .../ConnectionBuilderExtensions.cs | 12 ++ .../MqttConnectionContext.cs | 136 ++++++++++++++++++ .../MqttConnectionHandler.cs | 40 ++++++ .../MQTTnet.AspnetCore/ReaderExtensions.cs | 115 +++++++++++++++ .../MQTTnet.AspNetCore.csproj | 9 +- .../ServiceCollectionExtensions.cs | 3 +- Tests/MQTTnet.TestApp.AspNetCore2/Program.cs | 2 + 7 files changed, 313 insertions(+), 4 deletions(-) create mode 100644 Frameworks/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs create mode 100644 Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs create mode 100644 Frameworks/MQTTnet.AspnetCore/MqttConnectionHandler.cs create mode 100644 Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs diff --git a/Frameworks/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Frameworks/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs new file mode 100644 index 0000000..262a333 --- /dev/null +++ b/Frameworks/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -0,0 +1,12 @@ +using Microsoft.AspNetCore.Connections; + +namespace MQTTnet.AspNetCore +{ + public static class ConnectionBuilderExtensions + { + public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder) + { + return builder.UseConnectionHandler(); + } + } +} diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs new file mode 100644 index 0000000..0bd2b4e --- /dev/null +++ b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -0,0 +1,136 @@ +using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; +using MQTTnet.Packets; +using MQTTnet.Serializer; +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + public class MqttConnectionContext : IMqttChannelAdapter + { + public IMqttPacketSerializer PacketSerializer { get; } + public ConnectionContext Connection { get; } + + public string Endpoint => Connection.ConnectionId; + + public MqttConnectionContext( + IMqttPacketSerializer packetSerializer, + ConnectionContext connection) + { + PacketSerializer = packetSerializer; + Connection = connection; + } + + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + Connection.Transport.Input.Complete(); + Connection.Transport.Output.Complete(); + return Task.CompletedTask; + } + + public void Dispose() + { + } + + public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + var input = Connection.Transport.Input; + + while (!cancellationToken.IsCancellationRequested) + { + ReadResult readResult; + + var readTask = input.ReadAsync(cancellationToken); + if (readTask.IsCompleted) + { + readResult = readTask.Result; + } + else + { + readResult = await readTask; + } + + var buffer = readResult.Buffer; + + var consumed = buffer.Start; + var observed = buffer.Start; + + try + { + if (!buffer.IsEmpty) + { + if (PacketSerializer.TryDeserialize(buffer, out var packet, out consumed, out observed)) + { + return packet; + } + } + else if (readResult.IsCompleted) + { + break; + } + } + finally + { + // The buffer was sliced up to where it was consumed, so we can just advance to the start. + // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data + // before yielding the read again. + input.AdvanceTo(consumed, observed); + } + } + + cancellationToken.ThrowIfCancellationRequested(); + return null; + } + + public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets, CancellationToken cancellationToken) + { + foreach (var packet in packets) + { + await WriteAsync(packet); + } + } + + public async Task WriteAsync(MqttBasePacket packet) + { + var buffer = PacketSerializer.Serialize(packet); + await Connection.Transport.Output.WriteAsync(buffer.AsMemory()); + } + + private int messageId; + public Task PublishAsync(MqttPublishPacket packet) + { + if (!packet.PacketIdentifier.HasValue && packet.QualityOfServiceLevel > MQTTnet.Protocol.MqttQualityOfServiceLevel.AtMostOnce) + { + packet.PacketIdentifier = (ushort)Interlocked.Increment(ref messageId); + } + return WriteAsync(packet); + } + + public Task SubscribeAsync(MqttSubscribePacket packet) + { + if (!packet.PacketIdentifier.HasValue) + { + packet.PacketIdentifier = (ushort)Interlocked.Increment(ref messageId); + } + return WriteAsync(packet); + } + + public Task ConnectAsync(MqttConnectPacket packet) + { + if (string.IsNullOrEmpty(packet.ClientId)) + { + packet.ClientId = Guid.NewGuid().ToString(); + } + return WriteAsync(packet); + } + } +} diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionHandler.cs b/Frameworks/MQTTnet.AspnetCore/MqttConnectionHandler.cs new file mode 100644 index 0000000..49a5c09 --- /dev/null +++ b/Frameworks/MQTTnet.AspnetCore/MqttConnectionHandler.cs @@ -0,0 +1,40 @@ +using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; +using MQTTnet.Serializer; +using MQTTnet.Server; +using System; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + public class MqttConnectionHandler : ConnectionHandler, IMqttServerAdapter + { + public event EventHandler ClientAccepted; + + public override async Task OnConnectedAsync(ConnectionContext connection) + { + var serializer = new MqttPacketSerializer(); + using (var adapter = new MqttConnectionContext(serializer, connection)) + { + var args = new MqttServerAdapterClientAcceptedEventArgs(adapter); + ClientAccepted?.Invoke(this, args); + + await args.SessionTask; + } + } + + public Task StartAsync(IMqttServerOptions options) + { + return Task.CompletedTask; + } + + public Task StopAsync() + { + return Task.CompletedTask; + } + + public void Dispose() + { + } + } +} diff --git a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs new file mode 100644 index 0000000..a4e5e02 --- /dev/null +++ b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs @@ -0,0 +1,115 @@ +using System; +using System.Buffers; +using System.IO; +using MQTTnet.Exceptions; +using MQTTnet.Packets; +using MQTTnet.Protocol; +using MQTTnet.Serializer; + +namespace MQTTnet.AspNetCore +{ + public static class ReaderExtensions + { + public static MqttPacketHeader ReadHeader(this ref ReadOnlySequence input) + { + if (input.Length < 2) + { + return null; + } + + var fixedHeader = input.First.Span[0]; + var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); + var bodyLength = ReadBodyLength(ref input); + + return new MqttPacketHeader + { + FixedHeader = fixedHeader, + ControlPacketType = controlPacketType, + BodyLength = bodyLength + }; + } + + private static int ReadBodyLength(ref ReadOnlySequence input) + { + // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. + var multiplier = 1; + var value = 0; + byte encodedByte; + var index = 1; + + var temp = input.Slice(0, Math.Min(5, input.Length)).GetArray(); + + do + { + encodedByte = temp[index]; + index++; + + value += (byte)(encodedByte & 127) * multiplier; + if (multiplier > 128 * 128 * 128) + { + throw new MqttProtocolViolationException($"Remaining length is invalid (Data={string.Join(",", temp.AsSpan(1, index).ToArray())})."); + } + + multiplier *= 128; + } while ((encodedByte & 128) != 0); + + input = input.Slice(index); + + return value; + } + + + + public static byte[] GetArray(this in ReadOnlySequence input) + { + if (input.IsSingleSegment) + { + return input.First.Span.ToArray(); + } + + // Should be rare + return input.ToArray(); + } + + public static bool TryDeserialize(this IMqttPacketSerializer serializer, ref ReadOnlySequence input, out MqttBasePacket packet) + { + packet = null; + var copy = input; + var header = copy.ReadHeader(); + if (header == null || copy.Length < header.BodyLength) + { + return false; + } + + input = copy.Slice(header.BodyLength); + var bodySlice = copy.Slice(0, header.BodyLength); + using (var body = new MemoryStream(bodySlice.GetArray())) + { + packet = serializer.Deserialize(header, body); + return true; + } + } + + public static bool TryDeserialize(this IMqttPacketSerializer serializer, in ReadOnlySequence input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed) + { + packet = null; + var copy = input; + var header = copy.ReadHeader(); + if (header == null || copy.Length < header.BodyLength) + { + consumed = input.Start; + observed = input.End; + return false; + } + + var bodySlice = copy.Slice(0, header.BodyLength); + using (var body = new MemoryStream(bodySlice.GetArray())) + { + packet = serializer.Deserialize(header, body); + consumed = bodySlice.End; + observed = bodySlice.End; + return true; + } + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 6b85895..661ff16 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -9,6 +9,7 @@ + 7.2 @@ -16,9 +17,11 @@ - - - + + + + + diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 5b8f161..9d2a676 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -30,7 +30,8 @@ namespace MQTTnet.AspNetCore if (options.DefaultEndpointOptions.IsEnabled) { - services.AddSingleton(s => s.GetService()); + services.AddSingleton(); + services.AddSingleton(s => s.GetService()); } return services; diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs b/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs index 38b8c12..a716411 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs +++ b/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs @@ -1,5 +1,6 @@ using Microsoft.AspNetCore; using Microsoft.AspNetCore.Hosting; +using MQTTnet.AspNetCore; namespace MQTTnet.TestApp.AspNetCore2 { @@ -12,6 +13,7 @@ namespace MQTTnet.TestApp.AspNetCore2 private static IWebHost BuildWebHost(string[] args) => WebHost.CreateDefaultBuilder(args) + .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) .UseStartup() .Build(); } From 7a7fab090756d4390caf38c2bc6168edbdbf6d14 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Fri, 1 Jun 2018 21:13:55 +0200 Subject: [PATCH 03/18] fixed rebase --- .../MqttConnectionContext.cs | 8 ++- .../MQTTnet.AspnetCore/ReaderExtensions.cs | 62 +++++++++---------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs index 0bd2b4e..7683fd0 100644 --- a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -48,7 +48,8 @@ namespace MQTTnet.AspNetCore while (!cancellationToken.IsCancellationRequested) { ReadResult readResult; - + ReadingPacketStarted?.Invoke(this, EventArgs.Empty); + var readTask = input.ReadAsync(cancellationToken); if (readTask.IsCompleted) { @@ -84,6 +85,7 @@ namespace MQTTnet.AspNetCore // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data // before yielding the read again. input.AdvanceTo(consumed, observed); + ReadingPacketCompleted?.Invoke(this, EventArgs.Empty); } } @@ -106,6 +108,10 @@ namespace MQTTnet.AspNetCore } private int messageId; + + public event EventHandler ReadingPacketStarted; + public event EventHandler ReadingPacketCompleted; + public Task PublishAsync(MqttPublishPacket packet) { if (!packet.PacketIdentifier.HasValue && packet.QualityOfServiceLevel > MQTTnet.Protocol.MqttQualityOfServiceLevel.AtMostOnce) diff --git a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs index a4e5e02..718538b 100644 --- a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs @@ -1,6 +1,7 @@ using System; using System.Buffers; using System.IO; +using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -10,37 +11,23 @@ namespace MQTTnet.AspNetCore { public static class ReaderExtensions { - public static MqttPacketHeader ReadHeader(this ref ReadOnlySequence input) - { - if (input.Length < 2) - { - return null; - } - - var fixedHeader = input.First.Span[0]; - var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); - var bodyLength = ReadBodyLength(ref input); - - return new MqttPacketHeader - { - FixedHeader = fixedHeader, - ControlPacketType = controlPacketType, - BodyLength = bodyLength - }; - } - - private static int ReadBodyLength(ref ReadOnlySequence input) + private static bool TryReadBodyLength(ref ReadOnlySequence input, out int result) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 1; var value = 0; byte encodedByte; var index = 1; + result = 0; var temp = input.Slice(0, Math.Min(5, input.Length)).GetArray(); do { + if (index == temp.Length) + { + return false; + } encodedByte = temp[index]; index++; @@ -55,7 +42,8 @@ namespace MQTTnet.AspNetCore input = input.Slice(index); - return value; + result = value; + return true; } @@ -75,17 +63,22 @@ namespace MQTTnet.AspNetCore { packet = null; var copy = input; - var header = copy.ReadHeader(); - if (header == null || copy.Length < header.BodyLength) + if (copy.Length < 2) + { + return false; + } + + var fixedheader = copy.First.Span[0]; + if (!TryReadBodyLength(ref copy, out var bodyLength)) { return false; } - input = copy.Slice(header.BodyLength); - var bodySlice = copy.Slice(0, header.BodyLength); + input = copy.Slice(bodyLength); + var bodySlice = copy.Slice(0, bodyLength); using (var body = new MemoryStream(bodySlice.GetArray())) { - packet = serializer.Deserialize(header, body); + packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, body)); return true; } } @@ -93,19 +86,24 @@ namespace MQTTnet.AspNetCore public static bool TryDeserialize(this IMqttPacketSerializer serializer, in ReadOnlySequence input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed) { packet = null; + consumed = input.Start; + observed = input.End; var copy = input; - var header = copy.ReadHeader(); - if (header == null || copy.Length < header.BodyLength) + if (copy.Length < 2) + { + return false; + } + + var fixedheader = copy.First.Span[0]; + if (!TryReadBodyLength(ref copy, out var bodyLength)) { - consumed = input.Start; - observed = input.End; return false; } - var bodySlice = copy.Slice(0, header.BodyLength); + var bodySlice = copy.Slice(0, bodyLength); using (var body = new MemoryStream(bodySlice.GetArray())) { - packet = serializer.Deserialize(header, body); + packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, body)); consumed = bodySlice.End; observed = bodySlice.End; return true; From d5ac8c7183ff3af9313fc441afe858b3890012d9 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Fri, 1 Jun 2018 21:41:04 +0200 Subject: [PATCH 04/18] use final packages --- Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj | 8 +++----- .../MQTTnet.TestApp.AspNetCore2.csproj | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 661ff16..9754f40 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -17,11 +17,9 @@ - - - - - + + + diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj index c587ecc..f6ead87 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj +++ b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj @@ -1,7 +1,7 @@  - netcoreapp2.0 + netcoreapp2.1 Latest @@ -10,7 +10,7 @@ - + From 7c2adf636fd26d6cbc4e32b331d93c4c285d7826 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 2 Jun 2018 08:35:41 +0200 Subject: [PATCH 05/18] removed dead code and make new connection context an opt in so kestrel 2.0 still works --- .../MqttConnectionContext.cs | 40 ++----------------- .../MQTTnet.AspnetCore/ReaderExtensions.cs | 25 ------------ .../ServiceCollectionExtensions.cs | 12 +++++- Tests/MQTTnet.TestApp.AspNetCore2/Startup.cs | 8 +++- 4 files changed, 19 insertions(+), 66 deletions(-) diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs index 7683fd0..88bfcc9 100644 --- a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -97,46 +97,12 @@ namespace MQTTnet.AspNetCore { foreach (var packet in packets) { - await WriteAsync(packet); + var buffer = PacketSerializer.Serialize(packet); + await Connection.Transport.Output.WriteAsync(buffer.AsMemory()); } } - - public async Task WriteAsync(MqttBasePacket packet) - { - var buffer = PacketSerializer.Serialize(packet); - await Connection.Transport.Output.WriteAsync(buffer.AsMemory()); - } - - private int messageId; - + public event EventHandler ReadingPacketStarted; public event EventHandler ReadingPacketCompleted; - - public Task PublishAsync(MqttPublishPacket packet) - { - if (!packet.PacketIdentifier.HasValue && packet.QualityOfServiceLevel > MQTTnet.Protocol.MqttQualityOfServiceLevel.AtMostOnce) - { - packet.PacketIdentifier = (ushort)Interlocked.Increment(ref messageId); - } - return WriteAsync(packet); - } - - public Task SubscribeAsync(MqttSubscribePacket packet) - { - if (!packet.PacketIdentifier.HasValue) - { - packet.PacketIdentifier = (ushort)Interlocked.Increment(ref messageId); - } - return WriteAsync(packet); - } - - public Task ConnectAsync(MqttConnectPacket packet) - { - if (string.IsNullOrEmpty(packet.ClientId)) - { - packet.ClientId = Guid.NewGuid().ToString(); - } - return WriteAsync(packet); - } } } diff --git a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs index 718538b..5b83253 100644 --- a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs @@ -4,7 +4,6 @@ using System.IO; using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Packets; -using MQTTnet.Protocol; using MQTTnet.Serializer; namespace MQTTnet.AspNetCore @@ -59,30 +58,6 @@ namespace MQTTnet.AspNetCore return input.ToArray(); } - public static bool TryDeserialize(this IMqttPacketSerializer serializer, ref ReadOnlySequence input, out MqttBasePacket packet) - { - packet = null; - var copy = input; - if (copy.Length < 2) - { - return false; - } - - var fixedheader = copy.First.Span[0]; - if (!TryReadBodyLength(ref copy, out var bodyLength)) - { - return false; - } - - input = copy.Slice(bodyLength); - var bodySlice = copy.Slice(0, bodyLength); - using (var body = new MemoryStream(bodySlice.GetArray())) - { - packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, body)); - return true; - } - } - public static bool TryDeserialize(this IMqttPacketSerializer serializer, in ReadOnlySequence input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed) { packet = null; diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 9d2a676..fa061eb 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -30,11 +30,19 @@ namespace MQTTnet.AspNetCore if (options.DefaultEndpointOptions.IsEnabled) { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); + services.AddSingleton(); + services.AddSingleton(s => s.GetService()); } return services; } + + public static IServiceCollection AddMqttConnectionHandler(this IServiceCollection services) + { + services.AddSingleton(); + services.AddSingleton(s => s.GetService()); + + return services; + } } } diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/Startup.cs b/Tests/MQTTnet.TestApp.AspNetCore2/Startup.cs index 591cc87..f52040a 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/Startup.cs +++ b/Tests/MQTTnet.TestApp.AspNetCore2/Startup.cs @@ -17,8 +17,12 @@ namespace MQTTnet.TestApp.AspNetCore2 public void ConfigureServices(IServiceCollection services) { - var mqttServerOptions = new MqttServerOptionsBuilder().Build(); - services.AddHostedMqttServer(mqttServerOptions); + var mqttServerOptions = new MqttServerOptionsBuilder() + .WithoutDefaultEndpoint() + .Build(); + services + .AddHostedMqttServer(mqttServerOptions) + .AddMqttConnectionHandler(); } // In class _Startup_ of the ASP.NET Core 2.0 project. From 78c605966883cba51fa594e7cb5500beec0b5897 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 2 Jun 2018 09:31:27 +0200 Subject: [PATCH 06/18] enabled clients based on tcp to benefit from new api and added Benchmark for EndtoEnd with new API --- .../MqttClientConnectionContextFactory.cs | 34 +++++++++ .../Client}/Tcp/BufferExtensions.cs | 2 +- .../Client}/Tcp/DuplexPipe.cs | 2 +- .../Client}/Tcp/SocketAwaitable.cs | 2 +- .../Client}/Tcp/SocketReceiver.cs | 2 +- .../Client}/Tcp/SocketSender.cs | 2 +- .../Client}/Tcp/TcpConnection.cs | 45 +++++++---- .../MqttConnectionContext.cs | 6 ++ .../MQTTnet.AspnetCore/ReaderExtensions.cs | 11 +-- Source/MQTTnet/MqttFactory.cs | 8 ++ .../MQTTnet.Benchmarks.csproj | 2 + ...rocessingMqttConnectionContextBenchmark.cs | 76 +++++++++++++++++++ Tests/MQTTnet.Benchmarks/Program.cs | 4 + Tests/MQTTnet.Benchmarks/TcpPipesBenchmark.cs | 8 +- .../MQTTnet.TestApp.AspNetCore2.csproj | 1 + 15 files changed, 174 insertions(+), 31 deletions(-) create mode 100644 Frameworks/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs rename {Tests/MQTTnet.Benchmarks => Frameworks/MQTTnet.AspnetCore/Client}/Tcp/BufferExtensions.cs (93%) rename {Tests/MQTTnet.Benchmarks => Frameworks/MQTTnet.AspnetCore/Client}/Tcp/DuplexPipe.cs (96%) rename {Tests/MQTTnet.Benchmarks => Frameworks/MQTTnet.AspnetCore/Client}/Tcp/SocketAwaitable.cs (97%) rename {Tests/MQTTnet.Benchmarks => Frameworks/MQTTnet.AspnetCore/Client}/Tcp/SocketReceiver.cs (96%) rename {Tests/MQTTnet.Benchmarks => Frameworks/MQTTnet.AspnetCore/Client}/Tcp/SocketSender.cs (98%) rename {Tests/MQTTnet.Benchmarks => Frameworks/MQTTnet.AspnetCore/Client}/Tcp/TcpConnection.cs (85%) create mode 100644 Tests/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs b/Frameworks/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs new file mode 100644 index 0000000..e308ccd --- /dev/null +++ b/Frameworks/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs @@ -0,0 +1,34 @@ +using System; +using System.Net; +using MQTTnet.Adapter; +using MQTTnet.AspNetCore.Client.Tcp; +using MQTTnet.Client; +using MQTTnet.Diagnostics; +using MQTTnet.Serializer; + +namespace MQTTnet.AspNetCore.Client +{ + public class MqttClientConnectionContextFactory : IMqttClientAdapterFactory + { + public IMqttChannelAdapter CreateClientAdapter(IMqttClientOptions options, IMqttNetChildLogger logger) + { + if (options == null) throw new ArgumentNullException(nameof(options)); + + var serializer = new MqttPacketSerializer { ProtocolVersion = options.ProtocolVersion }; + + switch (options.ChannelOptions) + { + case MqttClientTcpOptions tcpOptions: + { + var endpoint = new DnsEndPoint(tcpOptions.Server, tcpOptions.GetPort()); + var tcpConnection = new TcpConnection(endpoint); + return new MqttConnectionContext(serializer, tcpConnection); + } + default: + { + throw new NotSupportedException(); + } + } + } + } +} diff --git a/Tests/MQTTnet.Benchmarks/Tcp/BufferExtensions.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs similarity index 93% rename from Tests/MQTTnet.Benchmarks/Tcp/BufferExtensions.cs rename to Frameworks/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs index 879306c..5911a3a 100644 --- a/Tests/MQTTnet.Benchmarks/Tcp/BufferExtensions.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs @@ -1,7 +1,7 @@ using System; using System.Runtime.InteropServices; -namespace MQTTnet.Benchmarks.Tcp +namespace MQTTnet.AspNetCore.Client.Tcp { public static class BufferExtensions { diff --git a/Tests/MQTTnet.Benchmarks/Tcp/DuplexPipe.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs similarity index 96% rename from Tests/MQTTnet.Benchmarks/Tcp/DuplexPipe.cs rename to Frameworks/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs index f5f3316..e234da5 100644 --- a/Tests/MQTTnet.Benchmarks/Tcp/DuplexPipe.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs @@ -1,6 +1,6 @@ using System.IO.Pipelines; -namespace MQTTnet.Benchmarks.Tcp +namespace MQTTnet.AspNetCore.Client.Tcp { public class DuplexPipe : IDuplexPipe { diff --git a/Tests/MQTTnet.Benchmarks/Tcp/SocketAwaitable.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs similarity index 97% rename from Tests/MQTTnet.Benchmarks/Tcp/SocketAwaitable.cs rename to Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs index 2271bd7..96160d1 100644 --- a/Tests/MQTTnet.Benchmarks/Tcp/SocketAwaitable.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs @@ -6,7 +6,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -namespace MQTTnet.Benchmarks.Tcp +namespace MQTTnet.AspNetCore.Client.Tcp { public class SocketAwaitable : ICriticalNotifyCompletion { diff --git a/Tests/MQTTnet.Benchmarks/Tcp/SocketReceiver.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs similarity index 96% rename from Tests/MQTTnet.Benchmarks/Tcp/SocketReceiver.cs rename to Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs index bc8e5c0..219b722 100644 --- a/Tests/MQTTnet.Benchmarks/Tcp/SocketReceiver.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs @@ -2,7 +2,7 @@ using System.IO.Pipelines; using System.Net.Sockets; -namespace MQTTnet.Benchmarks.Tcp +namespace MQTTnet.AspNetCore.Client.Tcp { public class SocketReceiver { diff --git a/Tests/MQTTnet.Benchmarks/Tcp/SocketSender.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs similarity index 98% rename from Tests/MQTTnet.Benchmarks/Tcp/SocketSender.cs rename to Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs index 7cb1bc1..c8ba832 100644 --- a/Tests/MQTTnet.Benchmarks/Tcp/SocketSender.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs @@ -5,7 +5,7 @@ using System.Diagnostics; using System.IO.Pipelines; using System.Net.Sockets; -namespace MQTTnet.Benchmarks.Tcp +namespace MQTTnet.AspNetCore.Client.Tcp { public class SocketSender { diff --git a/Tests/MQTTnet.Benchmarks/Tcp/TcpConnection.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs similarity index 85% rename from Tests/MQTTnet.Benchmarks/Tcp/TcpConnection.cs rename to Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs index 83cb98a..40ba295 100644 --- a/Tests/MQTTnet.Benchmarks/Tcp/TcpConnection.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs @@ -4,27 +4,35 @@ using System.IO.Pipelines; using System.Net; using System.Net.Sockets; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; using MQTTnet.Exceptions; -namespace MQTTnet.Benchmarks.Tcp +namespace MQTTnet.AspNetCore.Client.Tcp { - public class TcpConnection + public class TcpConnection : ConnectionContext { - private readonly Socket _socket; private volatile bool _aborted; private readonly EndPoint _endPoint; + private SocketSender _sender; + private SocketReceiver _receiver; + + private Socket _socket; private IDuplexPipe _application; - private IDuplexPipe _transport; - private readonly SocketSender _sender; - private readonly SocketReceiver _receiver; + + + public bool IsConnected { get; private set; } + + public override string ConnectionId { get; set; } + + public override IFeatureCollection Features { get; } + + public override IDictionary Items { get; set; } + public override IDuplexPipe Transport { get; set; } public TcpConnection(EndPoint endPoint) { - _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); _endPoint = endPoint; - - _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); - _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); } public TcpConnection(Socket socket) @@ -38,29 +46,34 @@ namespace MQTTnet.Benchmarks.Tcp public Task DisposeAsync() { - _transport?.Output.Complete(); - _transport?.Input.Complete(); + IsConnected = false; + + Transport?.Output.Complete(); + Transport?.Input.Complete(); _socket?.Dispose(); return Task.CompletedTask; } - public async Task StartAsync() + public async Task StartAsync() { - if (!_socket.Connected) + if (_socket == null) { + _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); + _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); await _socket.ConnectAsync(_endPoint); } var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); - _transport = pair.Transport; + Transport = pair.Transport; _application = pair.Application; _ = ExecuteAsync(); - return pair.Transport; + IsConnected = true; } private async Task ExecuteAsync() diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs index 88bfcc9..f7acfa4 100644 --- a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -1,5 +1,6 @@ using Microsoft.AspNetCore.Connections; using MQTTnet.Adapter; +using MQTTnet.AspNetCore.Client.Tcp; using MQTTnet.Packets; using MQTTnet.Serializer; using System; @@ -27,6 +28,10 @@ namespace MQTTnet.AspNetCore public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { + if (Connection is TcpConnection tcp && !tcp.IsConnected) + { + return tcp.StartAsync(); + } return Task.CompletedTask; } @@ -34,6 +39,7 @@ namespace MQTTnet.AspNetCore { Connection.Transport.Input.Complete(); Connection.Transport.Output.Complete(); + return Task.CompletedTask; } diff --git a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs index 5b83253..0411710 100644 --- a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs @@ -76,13 +76,10 @@ namespace MQTTnet.AspNetCore } var bodySlice = copy.Slice(0, bodyLength); - using (var body = new MemoryStream(bodySlice.GetArray())) - { - packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, body)); - consumed = bodySlice.End; - observed = bodySlice.End; - return true; - } + packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, new MqttPacketBodyReader(bodySlice.GetArray()))); + consumed = bodySlice.End; + observed = bodySlice.End; + return true; } } } diff --git a/Source/MQTTnet/MqttFactory.cs b/Source/MQTTnet/MqttFactory.cs index 00438fa..e830ad2 100644 --- a/Source/MQTTnet/MqttFactory.cs +++ b/Source/MQTTnet/MqttFactory.cs @@ -22,6 +22,14 @@ namespace MQTTnet return new MqttClient(new MqttClientAdapterFactory(), logger); } + public IMqttClient CreateMqttClient(IMqttNetLogger logger, IMqttClientAdapterFactory mqttClientAdapterFactory) + { + if (logger == null) throw new ArgumentNullException(nameof(logger)); + if (mqttClientAdapterFactory == null) throw new ArgumentNullException(nameof(mqttClientAdapterFactory)); + + return new MqttClient(mqttClientAdapterFactory, logger); + } + public IMqttServer CreateMqttServer() { var logger = new MqttNetLogger(); diff --git a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index 5e612fb..2a6b91a 100644 --- a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -10,9 +10,11 @@ + + diff --git a/Tests/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Tests/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs new file mode 100644 index 0000000..1e2fbb2 --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -0,0 +1,76 @@ +using BenchmarkDotNet.Attributes; +using MQTTnet.Client; + + +using MQTTnet.AspNetCore; + +using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Hosting; +using MQTTnet.Server; +using MQTTnet.Diagnostics; +using MQTTnet.AspNetCore.Client; + +namespace MQTTnet.Benchmarks +{ + [MemoryDiagnoser] + public class MessageProcessingMqttConnectionContextBenchmark + { + private IWebHost _host; + private IMqttClient _mqttClient; + private MqttApplicationMessage _message; + + [GlobalSetup] + public void Setup() + { + _host = WebHost.CreateDefaultBuilder() + .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) + .ConfigureServices(services => { + var mqttServerOptions = new MqttServerOptionsBuilder() + .WithoutDefaultEndpoint() + .Build(); + services + .AddHostedMqttServer(mqttServerOptions) + .AddMqttConnectionHandler(); + }) + .Configure(app => { + app.UseMqttServer(s => { + + }); + }) + .Build(); + + var factory = new MqttFactory(); + _mqttClient = factory.CreateMqttClient(new MqttNetLogger(), new MqttClientConnectionContextFactory()); + + _host.StartAsync().GetAwaiter().GetResult(); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost").Build(); + + _mqttClient.ConnectAsync(clientOptions).GetAwaiter().GetResult(); + + _message = new MqttApplicationMessageBuilder() + .WithTopic("A") + .Build(); + } + + [GlobalCleanup] + public void Cleanup() + { + _mqttClient.DisconnectAsync().GetAwaiter().GetResult(); + _mqttClient.Dispose(); + + _host.StopAsync().GetAwaiter().GetResult(); + _host.Dispose(); + } + + [Benchmark] + public void Send_10000_Messages() + { + for (var i = 0; i < 10000; i++) + { + _mqttClient.PublishAsync(_message).GetAwaiter().GetResult(); + } + } + } +} diff --git a/Tests/MQTTnet.Benchmarks/Program.cs b/Tests/MQTTnet.Benchmarks/Program.cs index 5306916..9407ccf 100644 --- a/Tests/MQTTnet.Benchmarks/Program.cs +++ b/Tests/MQTTnet.Benchmarks/Program.cs @@ -16,6 +16,7 @@ namespace MQTTnet.Benchmarks Console.WriteLine("5 = ChannelAdapterBenchmark"); Console.WriteLine("6 = MqttTcpChannelBenchmark"); Console.WriteLine("7 = TcpPipesBenchmark"); + Console.WriteLine("8 = MessageProcessingMqttConnectionContextBenchmark"); var pressedKey = Console.ReadKey(true); switch (pressedKey.KeyChar) @@ -41,6 +42,9 @@ namespace MQTTnet.Benchmarks case '7': BenchmarkRunner.Run(); break; + case '8': + BenchmarkRunner.Run(new AllowNonOptimized()); + break; } Console.ReadLine(); diff --git a/Tests/MQTTnet.Benchmarks/TcpPipesBenchmark.cs b/Tests/MQTTnet.Benchmarks/TcpPipesBenchmark.cs index cb3ffba..4252ab3 100644 --- a/Tests/MQTTnet.Benchmarks/TcpPipesBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/TcpPipesBenchmark.cs @@ -4,7 +4,7 @@ using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; using BenchmarkDotNet.Attributes; -using MQTTnet.Benchmarks.Tcp; +using MQTTnet.AspNetCore.Client.Tcp; namespace MQTTnet.Benchmarks { @@ -25,10 +25,12 @@ namespace MQTTnet.Benchmarks var clientConnection = new TcpConnection(new IPEndPoint(IPAddress.Loopback, 1883)); - _client = clientConnection.StartAsync().GetAwaiter().GetResult(); + clientConnection.StartAsync().GetAwaiter().GetResult(); + _client = clientConnection.Transport; var serverConnection = new TcpConnection(task.GetAwaiter().GetResult()); - _server = serverConnection.StartAsync().GetAwaiter().GetResult(); + serverConnection.StartAsync().GetAwaiter().GetResult(); + _server = serverConnection.Transport; } diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj index f6ead87..a216d72 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj +++ b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj @@ -10,6 +10,7 @@ + From c32a606a7064aef784f9042604b0af16d5d3e8b1 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 2 Jun 2018 11:53:58 +0200 Subject: [PATCH 07/18] add memory diagnoser --- Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs index 53b71fd..f6ff61a 100644 --- a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -9,6 +9,7 @@ namespace MQTTnet.Benchmarks { [ClrJob] [RPlotExporter, RankColumn] + [MemoryDiagnoser] public class MessageProcessingBenchmark { private IMqttServer _mqttServer; From cf947415e8c9b2e4bcd275642042ce030aaf5089 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 2 Jun 2018 11:54:22 +0200 Subject: [PATCH 08/18] enable abort exceptions --- .../MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs index 40ba295..4bf74fb 100644 --- a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs +++ b/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs @@ -131,14 +131,14 @@ namespace MQTTnet.AspNetCore.Client.Tcp if (!_aborted) { // Calling Dispose after ReceiveAsync can cause an "InvalidArgument" error on *nix. - //error = new MqttCommunicationException(); + error = ConnectionAborted(); } } catch (ObjectDisposedException) { if (!_aborted) { - //error = new MqttCommunicationException(); + error = ConnectionAborted(); } } catch (IOException ex) @@ -153,7 +153,7 @@ namespace MQTTnet.AspNetCore.Client.Tcp { if (_aborted) { - //error = error ?? new MqttCommunicationException(); + error = error ?? ConnectionAborted(); } _application.Output.Complete(error); @@ -193,6 +193,11 @@ namespace MQTTnet.AspNetCore.Client.Tcp } } + private Exception ConnectionAborted() + { + return new MqttCommunicationException("Connection Aborted"); + } + private async Task DoSend() { Exception error = null; From e65c3f2f2b84b47952f15824cfe961f897a21d43 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 16 Jun 2018 17:30:08 +0200 Subject: [PATCH 09/18] moved files into the correct place --- .../MqttConnectionContext.cs | 114 ----------------- .../MqttClientConnectionContextFactory.cs | 0 .../Client/Tcp/BufferExtensions.cs | 0 .../Client/Tcp/DuplexPipe.cs | 0 .../Client/Tcp/SocketAwaitable.cs | 0 .../Client/Tcp/SocketReceiver.cs | 0 .../Client/Tcp/SocketSender.cs | 0 .../Client/Tcp/TcpConnection.cs | 1 + .../ConnectionBuilderExtensions.cs | 0 .../MqttConnectionContext.cs | 120 ++++++++++++++++++ .../MqttConnectionHandler.cs | 0 .../MQTTnet.AspnetCore/ReaderExtensions.cs | 0 12 files changed, 121 insertions(+), 114 deletions(-) delete mode 100644 Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs (99%) rename {Frameworks => Source}/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs (100%) create mode 100644 Source/MQTTnet.AspnetCore/MqttConnectionContext.cs rename {Frameworks => Source}/MQTTnet.AspnetCore/MqttConnectionHandler.cs (100%) rename {Frameworks => Source}/MQTTnet.AspnetCore/ReaderExtensions.cs (100%) diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs deleted file mode 100644 index f7acfa4..0000000 --- a/Frameworks/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ /dev/null @@ -1,114 +0,0 @@ -using Microsoft.AspNetCore.Connections; -using MQTTnet.Adapter; -using MQTTnet.AspNetCore.Client.Tcp; -using MQTTnet.Packets; -using MQTTnet.Serializer; -using System; -using System.Collections.Generic; -using System.IO.Pipelines; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.AspNetCore -{ - public class MqttConnectionContext : IMqttChannelAdapter - { - public IMqttPacketSerializer PacketSerializer { get; } - public ConnectionContext Connection { get; } - - public string Endpoint => Connection.ConnectionId; - - public MqttConnectionContext( - IMqttPacketSerializer packetSerializer, - ConnectionContext connection) - { - PacketSerializer = packetSerializer; - Connection = connection; - } - - public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) - { - if (Connection is TcpConnection tcp && !tcp.IsConnected) - { - return tcp.StartAsync(); - } - return Task.CompletedTask; - } - - public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) - { - Connection.Transport.Input.Complete(); - Connection.Transport.Output.Complete(); - - return Task.CompletedTask; - } - - public void Dispose() - { - } - - public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) - { - var input = Connection.Transport.Input; - - while (!cancellationToken.IsCancellationRequested) - { - ReadResult readResult; - ReadingPacketStarted?.Invoke(this, EventArgs.Empty); - - var readTask = input.ReadAsync(cancellationToken); - if (readTask.IsCompleted) - { - readResult = readTask.Result; - } - else - { - readResult = await readTask; - } - - var buffer = readResult.Buffer; - - var consumed = buffer.Start; - var observed = buffer.Start; - - try - { - if (!buffer.IsEmpty) - { - if (PacketSerializer.TryDeserialize(buffer, out var packet, out consumed, out observed)) - { - return packet; - } - } - else if (readResult.IsCompleted) - { - break; - } - } - finally - { - // The buffer was sliced up to where it was consumed, so we can just advance to the start. - // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data - // before yielding the read again. - input.AdvanceTo(consumed, observed); - ReadingPacketCompleted?.Invoke(this, EventArgs.Empty); - } - } - - cancellationToken.ThrowIfCancellationRequested(); - return null; - } - - public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets, CancellationToken cancellationToken) - { - foreach (var packet in packets) - { - var buffer = PacketSerializer.Serialize(packet); - await Connection.Transport.Output.WriteAsync(buffer.AsMemory()); - } - } - - public event EventHandler ReadingPacketStarted; - public event EventHandler ReadingPacketCompleted; - } -} diff --git a/Frameworks/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs b/Source/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs rename to Source/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs rename to Source/MQTTnet.AspnetCore/Client/Tcp/BufferExtensions.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs rename to Source/MQTTnet.AspnetCore/Client/Tcp/DuplexPipe.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs rename to Source/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs rename to Source/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs rename to Source/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs diff --git a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs similarity index 99% rename from Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs rename to Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs index 4bf74fb..7417e74 100644 --- a/Frameworks/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs +++ b/Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.IO; using System.IO.Pipelines; using System.Net; diff --git a/Frameworks/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs rename to Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs new file mode 100644 index 0000000..3d42937 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -0,0 +1,120 @@ +using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; +using MQTTnet.AspNetCore.Client.Tcp; +using MQTTnet.Packets; +using MQTTnet.Serializer; +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + public class MqttConnectionContext : IMqttChannelAdapter + { + public IMqttPacketSerializer PacketSerializer { get; } + public ConnectionContext Connection { get; } + + public string Endpoint => Connection.ConnectionId; + + public MqttConnectionContext( + IMqttPacketSerializer packetSerializer, + ConnectionContext connection) + { + PacketSerializer = packetSerializer; + Connection = connection; + } + + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + if (Connection is TcpConnection tcp && !tcp.IsConnected) + { + return tcp.StartAsync(); + } + return Task.CompletedTask; + } + + public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + Connection.Transport.Input.Complete(); + Connection.Transport.Output.Complete(); + + return Task.CompletedTask; + } + + public void Dispose() + { + } + + public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + var input = Connection.Transport.Input; + + try + { + while (!cancellationToken.IsCancellationRequested) + { + ReadResult readResult; + var readTask = input.ReadAsync(cancellationToken); + if (readTask.IsCompleted) + { + readResult = readTask.Result; + } + else + { + readResult = await readTask; + } + + var buffer = readResult.Buffer; + + var consumed = buffer.Start; + var observed = buffer.Start; + + try + { + if (!buffer.IsEmpty) + { + if (PacketSerializer.TryDeserialize(buffer, out var packet, out consumed, out observed)) + { + return packet; + } + else + { + // we did receive something but the message is not yet complete + ReadingPacketStarted?.Invoke(this, EventArgs.Empty); + } + } + else if (readResult.IsCompleted) + { + break; + } + } + finally + { + // The buffer was sliced up to where it was consumed, so we can just advance to the start. + // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data + // before yielding the read again. + input.AdvanceTo(consumed, observed); + } + } + } + finally + { + ReadingPacketCompleted?.Invoke(this, EventArgs.Empty); + } + + cancellationToken.ThrowIfCancellationRequested(); + return null; + } + + public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + { + var buffer = PacketSerializer.Serialize(packet); + return Connection.Transport.Output.WriteAsync(buffer.AsMemory(), cancellationToken).AsTask(); + } + + public event EventHandler ReadingPacketStarted; + public event EventHandler ReadingPacketCompleted; + } +} diff --git a/Frameworks/MQTTnet.AspnetCore/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/MqttConnectionHandler.cs rename to Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs diff --git a/Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs b/Source/MQTTnet.AspnetCore/ReaderExtensions.cs similarity index 100% rename from Frameworks/MQTTnet.AspnetCore/ReaderExtensions.cs rename to Source/MQTTnet.AspnetCore/ReaderExtensions.cs From 410bc211a546f7d2e424b45270c0e13c2825ff22 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 16 Jun 2018 17:39:04 +0200 Subject: [PATCH 10/18] fixed packet reader --- Source/MQTTnet/Serializer/MqttPacketReader.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Source/MQTTnet/Serializer/MqttPacketReader.cs b/Source/MQTTnet/Serializer/MqttPacketReader.cs index d50fae5..7932956 100644 --- a/Source/MQTTnet/Serializer/MqttPacketReader.cs +++ b/Source/MQTTnet/Serializer/MqttPacketReader.cs @@ -19,7 +19,7 @@ namespace MQTTnet.Serializer while (totalBytesRead < buffer.Length) { - var bytesRead = await channel.ReadAsync(buffer, 0, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); + var bytesRead = await channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) { if (cancellationToken.IsCancellationRequested) From d573423372fa57ba2830c73ea866efdaad7d10fc Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sat, 16 Jun 2018 17:39:37 +0200 Subject: [PATCH 11/18] downgrade dependencies to 2.0 --- Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj | 4 ++-- Tests/MQTTnet.TestApp.AspNetCore2/Program.cs | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 9754f40..8db250b 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -18,8 +18,8 @@ - - + + diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs b/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs index a716411..5248ce6 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs +++ b/Tests/MQTTnet.TestApp.AspNetCore2/Program.cs @@ -13,7 +13,10 @@ namespace MQTTnet.TestApp.AspNetCore2 private static IWebHost BuildWebHost(string[] args) => WebHost.CreateDefaultBuilder(args) - .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) + .UseKestrel(o => { + o.ListenAnyIP(1883, l => l.UseMqtt()); + o.ListenAnyIP(5000); // default http pipeline + }) .UseStartup() .Build(); } From beb54acc2fd4f0f7148830040e49bede55e6371c Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sun, 17 Jun 2018 13:19:43 +0200 Subject: [PATCH 12/18] Refactor serializer benchmark to allow fair comparisons. --- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 2 +- .../Serializer/MqttPacketBodyReader.cs | 3 ++- .../MQTTnet.Benchmarks/SerializerBenchmark.cs | 24 +++++-------------- .../MqttPacketSerializerTests.cs | 2 +- 4 files changed, 10 insertions(+), 21 deletions(-) diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 7db9a2a..b21f4b3 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -192,7 +192,7 @@ namespace MQTTnet.Adapter bodyOffset += readBytes; } while (bodyOffset < body.Length); - return new ReceivedMqttPacket(fixedHeader.Flags, new MqttPacketBodyReader(body)); + return new ReceivedMqttPacket(fixedHeader.Flags, new MqttPacketBodyReader(body, 0)); } finally { diff --git a/Source/MQTTnet/Serializer/MqttPacketBodyReader.cs b/Source/MQTTnet/Serializer/MqttPacketBodyReader.cs index 04c0a22..d751c68 100644 --- a/Source/MQTTnet/Serializer/MqttPacketBodyReader.cs +++ b/Source/MQTTnet/Serializer/MqttPacketBodyReader.cs @@ -8,9 +8,10 @@ namespace MQTTnet.Serializer private readonly byte[] _buffer; private int _offset; - public MqttPacketBodyReader(byte[] buffer) + public MqttPacketBodyReader(byte[] buffer, int offset) { _buffer = buffer; + _offset = offset; } public int Length => _buffer.Length - _offset; diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index 06e7b39..5696c6a 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -46,29 +46,17 @@ namespace MQTTnet.Benchmarks { for (var i = 0; i < 10000; i++) { - using (var headerStream = new MemoryStream(Join(_serializedPacket))) + using (var stream = new MemoryStream()) { - var channel = new TestMqttChannel(headerStream); + stream.Write(_serializedPacket.Array, _serializedPacket.Offset, _serializedPacket.Count); + stream.Position = 0; + + var channel = new TestMqttChannel(stream); var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); - - using (var bodyStream = new MemoryStream(Join(_serializedPacket), (int)headerStream.Position, header.RemainingLength)) - { - _serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray()))); - } + _serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(_serializedPacket.Array, _serializedPacket.Count - header.RemainingLength))); } } } - - private static byte[] Join(params ArraySegment[] chunks) - { - var buffer = new MemoryStream(); - foreach (var chunk in chunks) - { - buffer.Write(chunk.Array, chunk.Offset, chunk.Count); - } - - return buffer.ToArray(); - } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 40aa953..0a0fa2b 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -422,7 +422,7 @@ namespace MQTTnet.Core.Tests using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) { - var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray()))); + var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray(), 0))); var buffer2 = serializer.Serialize(deserializedPacket); Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2))); From 56e9e15ac006876cac3ff4d97607a0f4e0349b92 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sun, 17 Jun 2018 13:47:58 +0200 Subject: [PATCH 13/18] Refactor Serializer memory usage. --- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 2 + .../Serializer/IMqttPacketSerializer.cs | 2 + .../Serializer/MqttPacketSerializer.cs | 73 ++++++++++++------- Source/MQTTnet/Serializer/MqttPacketWriter.cs | 21 ++++++ .../MQTTnet.Benchmarks/SerializerBenchmark.cs | 55 +++++++++++--- 5 files changed, 117 insertions(+), 36 deletions(-) diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index b21f4b3..cfbf743 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -95,6 +95,8 @@ namespace MQTTnet.Adapter var packetData = PacketSerializer.Serialize(packet); await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); + + PacketSerializer.FreeBuffer(); } catch (Exception exception) { diff --git a/Source/MQTTnet/Serializer/IMqttPacketSerializer.cs b/Source/MQTTnet/Serializer/IMqttPacketSerializer.cs index a81071b..e3066f9 100644 --- a/Source/MQTTnet/Serializer/IMqttPacketSerializer.cs +++ b/Source/MQTTnet/Serializer/IMqttPacketSerializer.cs @@ -11,5 +11,7 @@ namespace MQTTnet.Serializer ArraySegment Serialize(MqttBasePacket mqttPacket); MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket); + + void FreeBuffer(); } } \ No newline at end of file diff --git a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs index 811acf7..aafaa90 100644 --- a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs +++ b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs @@ -11,15 +11,19 @@ namespace MQTTnet.Serializer { private const int FixedHeaderSize = 1; + [ThreadStatic] + private static MqttPacketWriter _packetWriter; + public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; public ArraySegment Serialize(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); - var packetWriter = new MqttPacketWriter(); + var packetWriter = InitializePacketWriter(); // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) + packetWriter.Reset(); packetWriter.Seek(5); var fixedHeader = SerializePacket(packet, packetWriter); @@ -39,28 +43,6 @@ namespace MQTTnet.Serializer return new ArraySegment(buffer, headerOffset, packetWriter.Length - headerOffset); } - private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) - { - switch (packet) - { - case MqttConnectPacket connectPacket: return Serialize(connectPacket, packetWriter); - case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, packetWriter); - case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect); - case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq); - case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp); - case MqttPublishPacket publishPacket: return Serialize(publishPacket, packetWriter); - case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, packetWriter); - case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, packetWriter); - case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, packetWriter); - case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, packetWriter); - case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, packetWriter); - case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, packetWriter); - case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, packetWriter); - case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, packetWriter); - default: throw new MqttProtocolViolationException("Packet type invalid."); - } - } - public MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket) { if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket)); @@ -92,6 +74,43 @@ namespace MQTTnet.Serializer } } + public void FreeBuffer() + { + InitializePacketWriter().FreeBuffer(); + } + + private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) + { + switch (packet) + { + case MqttConnectPacket connectPacket: return Serialize(connectPacket, packetWriter); + case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, packetWriter); + case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect); + case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq); + case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp); + case MqttPublishPacket publishPacket: return Serialize(publishPacket, packetWriter); + case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, packetWriter); + case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, packetWriter); + case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, packetWriter); + case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, packetWriter); + case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, packetWriter); + case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, packetWriter); + case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, packetWriter); + case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, packetWriter); + default: throw new MqttProtocolViolationException("Packet type invalid."); + } + } + + private static MqttPacketWriter InitializePacketWriter() + { + if (_packetWriter == null) + { + _packetWriter = new MqttPacketWriter(); + } + + return _packetWriter; + } + private static MqttBasePacket DeserializeUnsubAck(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -310,7 +329,7 @@ namespace MQTTnet.Serializer var packet = new MqttConnAckPacket(); var acknowledgeFlags = body.ReadByte(); - + if (ProtocolVersion == MqttProtocolVersion.V311) { packet.IsSessionPresent = (acknowledgeFlags & 0x1) > 0; @@ -372,7 +391,7 @@ namespace MQTTnet.Serializer connectFlags |= 0x20; } } - + if (packet.Password != null && packet.Username == null) { throw new MqttProtocolViolationException("If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22]."); @@ -387,7 +406,7 @@ namespace MQTTnet.Serializer { connectFlags |= 0x80; } - + packetWriter.Write(connectFlags); packetWriter.Write(packet.KeepAlivePeriod); packetWriter.WriteWithLengthPrefix(packet.ClientId); @@ -424,7 +443,7 @@ namespace MQTTnet.Serializer { connectAcknowledgeFlags |= 0x1; } - + packetWriter.Write(connectAcknowledgeFlags); } else diff --git a/Source/MQTTnet/Serializer/MqttPacketWriter.cs b/Source/MQTTnet/Serializer/MqttPacketWriter.cs index c0c49fc..a2c6f5b 100644 --- a/Source/MQTTnet/Serializer/MqttPacketWriter.cs +++ b/Source/MQTTnet/Serializer/MqttPacketWriter.cs @@ -93,6 +93,11 @@ namespace MQTTnet.Serializer IncreasePostition(count); } + public void Reset() + { + Length = 0; + } + public void Seek(int offset) { EnsureCapacity(offset); @@ -104,6 +109,22 @@ namespace MQTTnet.Serializer return _buffer; } + public void FreeBuffer() + { + // This method frees the used memory by shrinking the buffer. This is required because the buffer + // is used across several messages. In general this is not a big issue because subsequent Ping packages + // have the same size but a very big publish package with 100 MB of payload will increase the buffer + // a lot and the size will never reduced. So this method tries to find a size which can be held in + // memory for a long time without causing troubles. + + if (_buffer.Length < 4096) + { + return; + } + + Array.Resize(ref _buffer, 4096); + } + private void EnsureAdditionalCapacity(int additionalCapacity) { var freeSpace = _buffer.Length - _position; diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index 5696c6a..2e9f22c 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -6,8 +6,9 @@ using BenchmarkDotNet.Attributes.Jobs; using BenchmarkDotNet.Attributes.Exporters; using System; using System.Threading; -using System.IO; +using System.Threading.Tasks; using MQTTnet.Adapter; +using MQTTnet.Channel; namespace MQTTnet.Benchmarks { @@ -38,6 +39,7 @@ namespace MQTTnet.Benchmarks for (var i = 0; i < 10000; i++) { _serializer.Serialize(_packet); + _serializer.FreeBuffer(); } } @@ -46,16 +48,51 @@ namespace MQTTnet.Benchmarks { for (var i = 0; i < 10000; i++) { - using (var stream = new MemoryStream()) - { - stream.Write(_serializedPacket.Array, _serializedPacket.Offset, _serializedPacket.Count); - stream.Position = 0; + var channel = new BenchmarkMqttChannel(_serializedPacket); - var channel = new TestMqttChannel(stream); + var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); + _serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(_serializedPacket.Array, _serializedPacket.Count - header.RemainingLength))); + } + } + + private class BenchmarkMqttChannel : IMqttChannel + { + private readonly ArraySegment _buffer; + private int _position; + + public BenchmarkMqttChannel(ArraySegment buffer) + { + _buffer = buffer; + _position = _buffer.Offset; + } + + public string Endpoint { get; } + + public Task ConnectAsync(CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); - _serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(_serializedPacket.Array, _serializedPacket.Count - header.RemainingLength))); - } + public Task DisconnectAsync() + { + throw new NotImplementedException(); + } + + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Array.Copy(_buffer.Array, _position, buffer, offset, count); + _position += count; + + return Task.FromResult(count); + } + + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public void Dispose() + { } } } From baac26c5f6f9033803af69d1b5daa72f4a076c7b Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sun, 17 Jun 2018 14:36:32 +0200 Subject: [PATCH 14/18] Refactor package deserialization. --- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 4 --- Source/MQTTnet/Serializer/Extensions.cs | 5 ++- Source/MQTTnet/Serializer/MqttFixedHeader.cs | 4 +-- Source/MQTTnet/Serializer/MqttPacketReader.cs | 33 ++++++++++++++----- .../Serializer/MqttPacketSerializer.cs | 9 +++-- .../MessageProcessingBenchmark.cs | 1 + .../MQTTnet.Benchmarks/SerializerBenchmark.cs | 16 +++++++-- 7 files changed, 50 insertions(+), 22 deletions(-) diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index cfbf743..94e0f33 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -157,10 +157,6 @@ namespace MQTTnet.Adapter private async Task ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) { var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, cancellationToken).ConfigureAwait(false); - if (fixedHeader == null) - { - return null; - } try { diff --git a/Source/MQTTnet/Serializer/Extensions.cs b/Source/MQTTnet/Serializer/Extensions.cs index 1a14de4..00ad7cf 100644 --- a/Source/MQTTnet/Serializer/Extensions.cs +++ b/Source/MQTTnet/Serializer/Extensions.cs @@ -12,7 +12,10 @@ namespace MQTTnet.Serializer } var buffer = new byte[source.Count]; - Buffer.BlockCopy(source.Array, source.Offset, buffer, 0, buffer.Length); + if (buffer.Length > 0) + { + Array.Copy(source.Array, source.Offset, buffer, 0, buffer.Length); + } return buffer; } diff --git a/Source/MQTTnet/Serializer/MqttFixedHeader.cs b/Source/MQTTnet/Serializer/MqttFixedHeader.cs index d87f63d..a8c2015 100644 --- a/Source/MQTTnet/Serializer/MqttFixedHeader.cs +++ b/Source/MQTTnet/Serializer/MqttFixedHeader.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Serializer { - public class MqttFixedHeader + public struct MqttFixedHeader { public MqttFixedHeader(byte flags, int remainingLength) { @@ -10,6 +10,6 @@ public byte Flags { get; } - public int RemainingLength { get; set; } + public int RemainingLength { get; } } } diff --git a/Source/MQTTnet/Serializer/MqttPacketReader.cs b/Source/MQTTnet/Serializer/MqttPacketReader.cs index d50fae5..5bbed52 100644 --- a/Source/MQTTnet/Serializer/MqttPacketReader.cs +++ b/Source/MQTTnet/Serializer/MqttPacketReader.cs @@ -1,4 +1,5 @@ -using System.Threading; +using System; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; using MQTTnet.Exceptions; @@ -8,23 +9,29 @@ namespace MQTTnet.Serializer { public static class MqttPacketReader { + [ThreadStatic] + private static byte[] _fixedHeaderBuffer; + + [ThreadStatic] + private static byte[] _singleByteBuffer; + public static async Task ReadFixedHeaderAsync(IMqttChannel channel, CancellationToken cancellationToken) { // The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. // So in all cases at least 2 bytes must be read for a complete MQTT packet. // async/await is used here because the next packet is received in a couple of minutes so the performance // impact is acceptable according to a useless waiting thread. - var buffer = new byte[2]; + var buffer = InitializeFixedHeaderBuffer(); var totalBytesRead = 0; while (totalBytesRead < buffer.Length) { - var bytesRead = await channel.ReadAsync(buffer, 0, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); + var bytesRead = await channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) { if (cancellationToken.IsCancellationRequested) { - return null; + throw new TaskCanceledException(); } ExceptionHelper.ThrowGracefulSocketClose(); @@ -56,7 +63,7 @@ namespace MQTTnet.Serializer // is too big for reading 1 byte in a row. We expect that the remaining data was sent // directly after the initial bytes. If the client disconnects just in this moment we // will get an exception anyway. - encodedByte = ReadByteAsync(channel, cancellationToken).GetAwaiter().GetResult(); + encodedByte = ReadByte(channel, cancellationToken); value += (byte)(encodedByte & 127) * multiplier; if (multiplier > 128 * 128 * 128) @@ -70,10 +77,10 @@ namespace MQTTnet.Serializer return value; } - private static async Task ReadByteAsync(IMqttChannel channel, CancellationToken cancellationToken) + private static byte ReadByte(IMqttChannel channel, CancellationToken cancellationToken) { - var buffer = new byte[1]; - var readCount = await channel.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); + var buffer = InitializeSingleByteBuffer(); + var readCount = channel.ReadAsync(buffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); if (readCount <= 0) { ExceptionHelper.ThrowGracefulSocketClose(); @@ -81,5 +88,15 @@ namespace MQTTnet.Serializer return buffer[0]; } + + private static byte[] InitializeFixedHeaderBuffer() + { + return _fixedHeaderBuffer ?? (_fixedHeaderBuffer = new byte[2]); + } + + private static byte[] InitializeSingleByteBuffer() + { + return _singleByteBuffer ?? (_singleByteBuffer = new byte[1]); + } } } diff --git a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs index aafaa90..159ba89 100644 --- a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs +++ b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs @@ -199,19 +199,18 @@ namespace MQTTnet.Serializer private static MqttBasePacket DeserializePublish(ReceivedMqttPacket receivedMqttPacket) { - var body = receivedMqttPacket.Body; - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(receivedMqttPacket.Body); var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0; - var topic = body.ReadStringWithLengthPrefix(); + var topic = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); ushort? packetIdentifier = null; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - packetIdentifier = body.ReadUInt16(); + packetIdentifier = receivedMqttPacket.Body.ReadUInt16(); } var packet = new MqttPublishPacket @@ -219,7 +218,7 @@ namespace MQTTnet.Serializer PacketIdentifier = packetIdentifier, Retain = retain, Topic = topic, - Payload = body.ReadRemainingData().ToArray(), + Payload = receivedMqttPacket.Body.ReadRemainingData().ToArray(), QualityOfServiceLevel = qualityOfServiceLevel, Dup = dup }; diff --git a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs index 53b71fd..f6ff61a 100644 --- a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -9,6 +9,7 @@ namespace MQTTnet.Benchmarks { [ClrJob] [RPlotExporter, RankColumn] + [MemoryDiagnoser] public class MessageProcessingBenchmark { private IMqttServer _mqttServer; diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index 2e9f22c..28de60f 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -46,12 +46,19 @@ namespace MQTTnet.Benchmarks [Benchmark] public void Deserialize_10000_Messages() { + var channel = new BenchmarkMqttChannel(_serializedPacket); + for (var i = 0; i < 10000; i++) { - var channel = new BenchmarkMqttChannel(_serializedPacket); + channel.Reset(); var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); - _serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(_serializedPacket.Array, _serializedPacket.Count - header.RemainingLength))); + + var receivedPacket = new ReceivedMqttPacket( + header.Flags, + new MqttPacketBodyReader(_serializedPacket.Array, _serializedPacket.Count - header.RemainingLength)); + + _serializer.Deserialize(receivedPacket); } } @@ -68,6 +75,11 @@ namespace MQTTnet.Benchmarks public string Endpoint { get; } + public void Reset() + { + _position = _buffer.Offset; + } + public Task ConnectAsync(CancellationToken cancellationToken) { throw new NotImplementedException(); From 66048931a2bae650582e638b010cee0104a797f8 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Sun, 17 Jun 2018 17:19:08 +0200 Subject: [PATCH 15/18] Refactor ASP net Core integration and fix breaking change. --- .../Client/Tcp/SocketAwaitable.cs | 3 +- .../Client/Tcp/SocketReceiver.cs | 1 - .../Client/Tcp/SocketSender.cs | 1 - .../Client/Tcp/TcpConnection.cs | 6 -- .../MqttConnectionContext.cs | 25 +++---- Source/MQTTnet.AspnetCore/MqttHostedServer.cs | 3 +- Source/MQTTnet.AspnetCore/ReaderExtensions.cs | 75 +++++++++---------- 7 files changed, 51 insertions(+), 63 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs index 96160d1..dbc2612 100644 --- a/Source/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs +++ b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketAwaitable.cs @@ -23,9 +23,10 @@ namespace MQTTnet.AspNetCore.Client.Tcp _ioScheduler = ioScheduler; } - public SocketAwaitable GetAwaiter() => this; public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); + public SocketAwaitable GetAwaiter() => this; + public int GetResult() { Debug.Assert(ReferenceEquals(_callback, _callbackCompleted)); diff --git a/Source/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs index 219b722..7d11fa2 100644 --- a/Source/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs +++ b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketReceiver.cs @@ -24,7 +24,6 @@ namespace MQTTnet.AspNetCore.Client.Tcp _eventArgs.SetBuffer(buffer); #else var segment = buffer.GetArray(); - _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); #endif if (!_socket.ReceiveAsync(_eventArgs)) diff --git a/Source/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs index c8ba832..55192d6 100644 --- a/Source/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs +++ b/Source/MQTTnet.AspnetCore/Client/Tcp/SocketSender.cs @@ -61,7 +61,6 @@ namespace MQTTnet.AspNetCore.Client.Tcp _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); #else var segment = memory.GetArray(); - _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); #endif if (!_socket.SendAsync(_eventArgs)) diff --git a/Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs b/Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs index 7417e74..37913fe 100644 --- a/Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs +++ b/Source/MQTTnet.AspnetCore/Client/Tcp/TcpConnection.cs @@ -21,13 +21,9 @@ namespace MQTTnet.AspNetCore.Client.Tcp private Socket _socket; private IDuplexPipe _application; - public bool IsConnected { get; private set; } - public override string ConnectionId { get; set; } - public override IFeatureCollection Features { get; } - public override IDictionary Items { get; set; } public override IDuplexPipe Transport { get; set; } @@ -209,11 +205,9 @@ namespace MQTTnet.AspNetCore.Client.Tcp } catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted) { - error = null; } catch (ObjectDisposedException) { - error = null; } catch (IOException ex) { diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index 3d42937..8a8469d 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -4,7 +4,6 @@ using MQTTnet.AspNetCore.Client.Tcp; using MQTTnet.Packets; using MQTTnet.Serializer; using System; -using System.Collections.Generic; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; @@ -13,11 +12,6 @@ namespace MQTTnet.AspNetCore { public class MqttConnectionContext : IMqttChannelAdapter { - public IMqttPacketSerializer PacketSerializer { get; } - public ConnectionContext Connection { get; } - - public string Endpoint => Connection.ConnectionId; - public MqttConnectionContext( IMqttPacketSerializer packetSerializer, ConnectionContext connection) @@ -26,6 +20,12 @@ namespace MQTTnet.AspNetCore Connection = connection; } + public string Endpoint => Connection.ConnectionId; + public ConnectionContext Connection { get; } + public IMqttPacketSerializer PacketSerializer { get; } + public event EventHandler ReadingPacketStarted; + public event EventHandler ReadingPacketCompleted; + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { if (Connection is TcpConnection tcp && !tcp.IsConnected) @@ -43,10 +43,6 @@ namespace MQTTnet.AspNetCore return Task.CompletedTask; } - public void Dispose() - { - } - public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { var input = Connection.Transport.Input; @@ -106,15 +102,16 @@ namespace MQTTnet.AspNetCore cancellationToken.ThrowIfCancellationRequested(); return null; - } + } public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { var buffer = PacketSerializer.Serialize(packet); return Connection.Transport.Output.WriteAsync(buffer.AsMemory(), cancellationToken).AsTask(); } - - public event EventHandler ReadingPacketStarted; - public event EventHandler ReadingPacketCompleted; + + public void Dispose() + { + } } } diff --git a/Source/MQTTnet.AspnetCore/MqttHostedServer.cs b/Source/MQTTnet.AspnetCore/MqttHostedServer.cs index 708d383..4b6b436 100644 --- a/Source/MQTTnet.AspnetCore/MqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/MqttHostedServer.cs @@ -13,7 +13,8 @@ namespace MQTTnet.AspNetCore { private readonly IMqttServerOptions _options; - public MqttHostedServer(IMqttServerOptions options, IEnumerable adapters, IMqttNetLogger logger) : base(adapters, logger.CreateChildLogger(nameof(MqttHostedServer))) + public MqttHostedServer(IMqttServerOptions options, IEnumerable adapters, IMqttNetLogger logger) + : base(adapters, logger.CreateChildLogger(nameof(MqttHostedServer))) { _options = options ?? throw new ArgumentNullException(nameof(options)); } diff --git a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs b/Source/MQTTnet.AspnetCore/ReaderExtensions.cs index 0411710..2b7d8a6 100644 --- a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ReaderExtensions.cs @@ -1,6 +1,5 @@ using System; using System.Buffers; -using System.IO; using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Packets; @@ -8,8 +7,43 @@ using MQTTnet.Serializer; namespace MQTTnet.AspNetCore { - public static class ReaderExtensions + public static class ReaderExtensions { + public static bool TryDeserialize(this IMqttPacketSerializer serializer, in ReadOnlySequence input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed) + { + packet = null; + consumed = input.Start; + observed = input.End; + var copy = input; + if (copy.Length < 2) + { + return false; + } + + var fixedheader = copy.First.Span[0]; + if (!TryReadBodyLength(ref copy, out var bodyLength)) + { + return false; + } + + var bodySlice = copy.Slice(0, bodyLength); + packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, new MqttPacketBodyReader(bodySlice.GetArray(), 0))); + consumed = bodySlice.End; + observed = bodySlice.End; + return true; + } + + private static byte[] GetArray(this in ReadOnlySequence input) + { + if (input.IsSingleSegment) + { + return input.First.Span.ToArray(); + } + + // Should be rare + return input.ToArray(); + } + private static bool TryReadBodyLength(ref ReadOnlySequence input, out int result) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. @@ -44,42 +78,5 @@ namespace MQTTnet.AspNetCore result = value; return true; } - - - - public static byte[] GetArray(this in ReadOnlySequence input) - { - if (input.IsSingleSegment) - { - return input.First.Span.ToArray(); - } - - // Should be rare - return input.ToArray(); - } - - public static bool TryDeserialize(this IMqttPacketSerializer serializer, in ReadOnlySequence input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed) - { - packet = null; - consumed = input.Start; - observed = input.End; - var copy = input; - if (copy.Length < 2) - { - return false; - } - - var fixedheader = copy.First.Span[0]; - if (!TryReadBodyLength(ref copy, out var bodyLength)) - { - return false; - } - - var bodySlice = copy.Slice(0, bodyLength); - packet = serializer.Deserialize(new ReceivedMqttPacket(fixedheader, new MqttPacketBodyReader(bodySlice.GetArray()))); - consumed = bodySlice.End; - observed = bodySlice.End; - return true; - } } } From bc20850fba2fbf13e9d60d748c2dd30ad2278c17 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Tue, 19 Jun 2018 20:52:31 +0200 Subject: [PATCH 16/18] Refactor serializer locking and thread instances. --- Build/MQTTnet.AspNetCore.nuspec | 6 ++-- Build/MQTTnet.Extensions.ManagedClient.nuspec | 19 +--------- Build/MQTTnet.Extensions.Rpc.nuspec | 19 +--------- Build/MQTTnet.nuspec | 2 +- Build/build.ps1 | 8 ++--- Build/upload.ps1 | 9 +++++ MQTTnet.sln | 2 ++ .../ManagedMqttClient.cs | 35 ++++++++----------- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 7 ++++ .../MQTTnet/Diagnostics/MqttNetChildLogger.cs | 2 +- .../Implementations/MqttTcpChannel.Uwp.cs | 22 +++++++++--- Source/MQTTnet/MQTTnet.csproj | 5 ++- .../Serializer/MqttPacketSerializer.cs | 35 ++++++------------- Source/MQTTnet/Serializer/MqttPacketWriter.cs | 6 ++-- .../Server/MqttClientSessionsManager.cs | 2 +- .../MainPage.xaml | 31 ++++++++++------ .../MainPage.xaml.cs | 20 ++++++++++- 17 files changed, 120 insertions(+), 110 deletions(-) create mode 100644 Build/upload.ps1 diff --git a/Build/MQTTnet.AspNetCore.nuspec b/Build/MQTTnet.AspNetCore.nuspec index 9f59562..669f205 100644 --- a/Build/MQTTnet.AspNetCore.nuspec +++ b/Build/MQTTnet.AspNetCore.nuspec @@ -13,11 +13,9 @@ * Updated to MQTTnet 2.8.0. Copyright Christian Kratky 2016-2018 - MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin + MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin - - - + diff --git a/Build/MQTTnet.Extensions.ManagedClient.nuspec b/Build/MQTTnet.Extensions.ManagedClient.nuspec index 714ed9e..6d7f7f6 100644 --- a/Build/MQTTnet.Extensions.ManagedClient.nuspec +++ b/Build/MQTTnet.Extensions.ManagedClient.nuspec @@ -15,24 +15,7 @@ Copyright Christian Kratky 2016-2018 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin - - - - - - - - - - - - - - - - - - + diff --git a/Build/MQTTnet.Extensions.Rpc.nuspec b/Build/MQTTnet.Extensions.Rpc.nuspec index 16a51c2..d7eddb0 100644 --- a/Build/MQTTnet.Extensions.Rpc.nuspec +++ b/Build/MQTTnet.Extensions.Rpc.nuspec @@ -15,24 +15,7 @@ Copyright Christian Kratky 2016-2018 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin - - - - - - - - - - - - - - - - - - + diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 5d58437..556957d 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -48,7 +48,7 @@ - + diff --git a/Build/build.ps1 b/Build/build.ps1 index 59f3112..0b6f074 100644 --- a/Build/build.ps1 +++ b/Build/build.ps1 @@ -36,8 +36,8 @@ if ($path) { Remove-Item .\NuGet -Force -Recurse -ErrorAction SilentlyContinue New-Item -ItemType Directory -Force -Path .\NuGet - .\NuGet.exe pack MQTTnet.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion - .\NuGet.exe pack MQTTnet.AspNetCore.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion - .\NuGet.exe pack MQTTnet.Extensions.Rpc.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion - .\NuGet.exe pack MQTTnet.Extensions.ManagedClient.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion + .\nuget.exe pack MQTTnet.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion + .\nuget.exe pack MQTTnet.AspNetCore.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion + .\nuget.exe pack MQTTnet.Extensions.Rpc.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion + .\nuget.exe pack MQTTnet.Extensions.ManagedClient.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion } \ No newline at end of file diff --git a/Build/upload.ps1 b/Build/upload.ps1 new file mode 100644 index 0000000..a7cb172 --- /dev/null +++ b/Build/upload.ps1 @@ -0,0 +1,9 @@ +param([string]$apiKey) + +$files = Get-ChildItem -Path ".\NuGet" -Filter "*.nupkg" +foreach ($file in $files) +{ + Write-Host "Uploading: " $file + + .\nuget.exe push $file.Fullname $apiKey -NoSymbols -Source https://api.nuget.org/v3/index.json +} \ No newline at end of file diff --git a/MQTTnet.sln b/MQTTnet.sln index 0e0e60c..a46c295 100644 --- a/MQTTnet.sln +++ b/MQTTnet.sln @@ -20,6 +20,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Build", "Build", "{67C28AC1 Build\MQTTnet.Extensions.ManagedClient.nuspec = Build\MQTTnet.Extensions.ManagedClient.nuspec Build\MQTTnet.Extensions.Rpc.nuspec = Build\MQTTnet.Extensions.Rpc.nuspec Build\MQTTnet.nuspec = Build\MQTTnet.nuspec + Build\upload.ps1 = Build\upload.ps1 EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{B3F60ECB-45BA-4C66-8903-8BB89CA67998}" @@ -74,6 +75,7 @@ Global {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|x86.Build.0 = Release|Any CPU {FF1F72D6-9524-4422-9497-3CC0002216ED}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {FF1F72D6-9524-4422-9497-3CC0002216ED}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FF1F72D6-9524-4422-9497-3CC0002216ED}.Debug|Any CPU.Deploy.0 = Debug|Any CPU {FF1F72D6-9524-4422-9497-3CC0002216ED}.Debug|ARM.ActiveCfg = Debug|ARM {FF1F72D6-9524-4422-9497-3CC0002216ED}.Debug|ARM.Build.0 = Debug|ARM {FF1F72D6-9524-4422-9497-3CC0002216ED}.Debug|ARM.Deploy.0 = Debug|ARM diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs index 9bd14b2..687b0cd 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs @@ -7,7 +7,6 @@ using System.Threading.Tasks; using MQTTnet.Client; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Protocol; namespace MQTTnet.Extensions.ManagedClient @@ -15,8 +14,7 @@ namespace MQTTnet.Extensions.ManagedClient public class ManagedMqttClient : IManagedMqttClient { private readonly BlockingCollection _messageQueue = new BlockingCollection(); - private readonly Dictionary _subscriptions = new Dictionary(); - private readonly AsyncLock _subscriptionsLock = new AsyncLock(); + private readonly ConcurrentDictionary _subscriptions = new ConcurrentDictionary(); private readonly List _unsubscriptions = new List(); private readonly IMqttClient _mqttClient; @@ -118,39 +116,36 @@ namespace MQTTnet.Extensions.ManagedClient _messageQueue.Add(applicationMessage); } - public async Task SubscribeAsync(IEnumerable topicFilters) + public Task SubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - using (await _subscriptionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + foreach (var topicFilter in topicFilters) { - foreach (var topicFilter in topicFilters) - { - _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; - _subscriptionsNotPushed = true; - } + _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + _subscriptionsNotPushed = true; } + + return Task.FromResult(0); } - public async Task UnsubscribeAsync(IEnumerable topics) + public Task UnsubscribeAsync(IEnumerable topics) { - using (await _subscriptionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + foreach (var topic in topics) { - foreach (var topic in topics) + if (_subscriptions.TryRemove(topic, out _)) { - if (_subscriptions.Remove(topic)) - { - _unsubscriptions.Add(topic); - _subscriptionsNotPushed = true; - } + _unsubscriptions.Add(topic); + _subscriptionsNotPushed = true; } } + + return Task.FromResult(0); } public void Dispose() { _messageQueue?.Dispose(); - _subscriptionsLock?.Dispose(); _connectionCancellationToken?.Dispose(); _publishingCancellationToken?.Dispose(); } @@ -289,7 +284,7 @@ namespace MQTTnet.Extensions.ManagedClient List subscriptions; List unsubscriptions; - using (await _subscriptionsLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + lock (_subscriptions) { subscriptions = _subscriptions.Select(i => new TopicFilter(i.Key, i.Value)).ToList(); diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 94e0f33..61912ef 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -18,6 +18,8 @@ namespace MQTTnet.Adapter private const uint ErrorOperationAborted = 0x800703E3; private const int ReadBufferSize = 4096; // TODO: Move buffer size to config + private readonly SemaphoreSlim _writerSemaphore = new SemaphoreSlim(1, 1); + private readonly IMqttNetChildLogger _logger; private readonly IMqttChannel _channel; @@ -88,6 +90,7 @@ namespace MQTTnet.Adapter public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { + await _writerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); try { _logger.Verbose("TX >>> {0}", packet); @@ -107,6 +110,10 @@ namespace MQTTnet.Adapter WrapException(exception); } + finally + { + _writerSemaphore.Release(); + } } public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) diff --git a/Source/MQTTnet/Diagnostics/MqttNetChildLogger.cs b/Source/MQTTnet/Diagnostics/MqttNetChildLogger.cs index 1ae5a9c..3733454 100644 --- a/Source/MQTTnet/Diagnostics/MqttNetChildLogger.cs +++ b/Source/MQTTnet/Diagnostics/MqttNetChildLogger.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Diagnostics public MqttNetChildLogger(IMqttNetLogger logger, string source) { - _logger = logger; + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _source = source; } diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs index d7a030a..2ccc7d9 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs @@ -40,7 +40,18 @@ namespace MQTTnet.Implementations public static Func> CustomIgnorableServerCertificateErrorsResolver { get; set; } - public string Endpoint => _socket?.Information?.RemoteAddress?.ToString(); // TODO: Check if contains also the port. + public string Endpoint + { + get + { + if (_socket?.Information != null) + { + return _socket.Information.RemoteAddress + ":" + _socket.Information.RemotePort; + } + + return null; + } + } public async Task ConnectAsync(CancellationToken cancellationToken) { @@ -81,10 +92,13 @@ namespace MQTTnet.Implementations return _readStream.ReadAsync(buffer, offset, count, cancellationToken); } - public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - await _writeStream.WriteAsync(buffer, offset, count, cancellationToken); - await _writeStream.FlushAsync(cancellationToken); + // In the write method only the internal buffer will be filled. So here is no + // async/await required. The real network transmit is done when calling the + // Flush method. + _writeStream.Write(buffer, offset, count); + return _writeStream.FlushAsync(cancellationToken); } public void Dispose() diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index 4afd09d..8364378 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -1,7 +1,7 @@  - netstandard1.3;netstandard2.0;net452;uap10.0 + netstandard1.3;netstandard2.0;net452;net461;uap10.0 netstandard1.3;netstandard2.0 MQTTnet MQTTnet @@ -62,4 +62,7 @@ + + + \ No newline at end of file diff --git a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs index 159ba89..537513d 100644 --- a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs +++ b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs @@ -11,8 +11,7 @@ namespace MQTTnet.Serializer { private const int FixedHeaderSize = 1; - [ThreadStatic] - private static MqttPacketWriter _packetWriter; + private readonly MqttPacketWriter _packetWriter = new MqttPacketWriter(); public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; @@ -20,14 +19,12 @@ namespace MQTTnet.Serializer { if (packet == null) throw new ArgumentNullException(nameof(packet)); - var packetWriter = InitializePacketWriter(); - // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) - packetWriter.Reset(); - packetWriter.Seek(5); + _packetWriter.Reset(); + _packetWriter.Seek(5); - var fixedHeader = SerializePacket(packet, packetWriter); - var remainingLength = packetWriter.Length - 5; + var fixedHeader = SerializePacket(packet, _packetWriter); + var remainingLength = _packetWriter.Length - 5; var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); @@ -35,12 +32,12 @@ namespace MQTTnet.Serializer var headerOffset = 5 - headerSize; // Position cursor on correct offset on beginining of array (has leading 0x0) - packetWriter.Seek(headerOffset); - packetWriter.Write(fixedHeader); - packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); + _packetWriter.Seek(headerOffset); + _packetWriter.Write(fixedHeader); + _packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); - var buffer = packetWriter.GetBuffer(); - return new ArraySegment(buffer, headerOffset, packetWriter.Length - headerOffset); + var buffer = _packetWriter.GetBuffer(); + return new ArraySegment(buffer, headerOffset, _packetWriter.Length - headerOffset); } public MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket) @@ -76,7 +73,7 @@ namespace MQTTnet.Serializer public void FreeBuffer() { - InitializePacketWriter().FreeBuffer(); + _packetWriter.FreeBuffer(); } private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) @@ -101,16 +98,6 @@ namespace MQTTnet.Serializer } } - private static MqttPacketWriter InitializePacketWriter() - { - if (_packetWriter == null) - { - _packetWriter = new MqttPacketWriter(); - } - - return _packetWriter; - } - private static MqttBasePacket DeserializeUnsubAck(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); diff --git a/Source/MQTTnet/Serializer/MqttPacketWriter.cs b/Source/MQTTnet/Serializer/MqttPacketWriter.cs index a2c6f5b..ba535bf 100644 --- a/Source/MQTTnet/Serializer/MqttPacketWriter.cs +++ b/Source/MQTTnet/Serializer/MqttPacketWriter.cs @@ -12,6 +12,8 @@ namespace MQTTnet.Serializer /// public class MqttPacketWriter { + public static int MaxBufferSize = 4096; + private byte[] _buffer = new byte[128]; private int _position; @@ -117,12 +119,12 @@ namespace MQTTnet.Serializer // a lot and the size will never reduced. So this method tries to find a size which can be held in // memory for a long time without causing troubles. - if (_buffer.Length < 4096) + if (_buffer.Length < MaxBufferSize) { return; } - Array.Resize(ref _buffer, 4096); + Array.Resize(ref _buffer, MaxBufferSize); } private void EnsureAdditionalCapacity(int additionalCapacity) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 1cb301a..5937265 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -148,7 +148,7 @@ namespace MQTTnet.Server clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage); } } - catch (TaskCanceledException) + catch (OperationCanceledException) { } catch (Exception exception) diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml index 906d9b8..41617df 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml @@ -4,6 +4,8 @@ xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" + xmlns:server="using:MQTTnet.Server" + xmlns:interop="using:Windows.UI.Xaml.Interop" d:DesignHeight="800" d:DesignWidth="800" mc:Ignorable="d"> @@ -31,7 +33,7 @@ Keep alive interval: - + TCP WS @@ -142,6 +144,7 @@ Persist retained messages in JSON format Clear previously retained messages on startup + Allow persistent sessions @@ -149,11 +152,6 @@ - - - - - @@ -162,12 +160,23 @@ - + - - - - + + + + + + + + + + + + + diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs index 76f947d..387364a 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Concurrent; +using System.Collections.ObjectModel; using System.Text; using System.Threading.Tasks; using Windows.Security.Cryptography.Certificates; @@ -21,6 +22,7 @@ namespace MQTTnet.TestApp.UniversalWindows public sealed partial class MainPage { private readonly ConcurrentQueue _traceMessages = new ConcurrentQueue(); + private readonly ObservableCollection _sessions = new ObservableCollection(); private IMqttClient _mqttClient; private IMqttServer _mqttServer; @@ -306,6 +308,7 @@ namespace MQTTnet.TestApp.UniversalWindows var options = new MqttServerOptions(); options.DefaultEndpointOptions.Port = int.Parse(ServerPort.Text); options.Storage = storage; + options.EnablePersistentSessions = ServerAllowPersistentSessions.IsChecked == true; await _mqttServer.StartAsync(options); } @@ -374,10 +377,25 @@ namespace MQTTnet.TestApp.UniversalWindows private void ClearSessions(object sender, RoutedEventArgs e) { + _sessions.Clear(); } - private void RefreshSessions(object sender, RoutedEventArgs e) + private async void RefreshSessions(object sender, RoutedEventArgs e) { + if (_mqttServer == null) + { + return; + } + + var sessions = await _mqttServer.GetClientSessionsStatusAsync(); + _sessions.Clear(); + + foreach (var session in sessions) + { + _sessions.Add(session); + } + + ListViewSessions.DataContext = _sessions; } private async Task WikiCode() From 64b38487bc09a6a6f566e30df20b60bc39f1009f Mon Sep 17 00:00:00 2001 From: Jan Eggers Date: Wed, 20 Jun 2018 08:20:54 +0200 Subject: [PATCH 17/18] improve session manager --- .../Server/MqttClientSessionsManager.cs | 65 +++++++++++++------ 1 file changed, 46 insertions(+), 19 deletions(-) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 1cb301a..1b507d3 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -14,7 +15,11 @@ namespace MQTTnet.Server public class MqttClientSessionsManager { private readonly BlockingCollection _messageQueue = new BlockingCollection(); - private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); + + /// + /// manual locking dictionaries is faster than using concurrent dictionary + /// + private readonly Dictionary _sessions = new Dictionary(); private readonly CancellationToken _cancellationToken; @@ -43,12 +48,16 @@ namespace MQTTnet.Server public Task StopAsync() { - foreach (var session in _sessions) + lock (_sessions) { - session.Value.Stop(MqttClientDisconnectType.NotClean); - } + foreach (var session in _sessions) + { + session.Value.Stop(MqttClientDisconnectType.NotClean); + } - _sessions.Clear(); + _sessions.Clear(); + } + _messageQueue.Dispose(); return Task.FromResult(0); } @@ -60,10 +69,11 @@ namespace MQTTnet.Server public Task> GetClientStatusAsync() { var result = new List(); - foreach (var session in _sessions) + + foreach (var session in GetSessions()) { - var status = new MqttClientSessionStatus(this, session.Value); - session.Value.FillStatus(status); + var status = new MqttClientSessionStatus(this, session); + session.FillStatus(status); result.Add(status); } @@ -83,12 +93,15 @@ namespace MQTTnet.Server if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - if (!_sessions.TryGetValue(clientId, out var session)) + lock (_sessions) { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); + } - return session.SubscribeAsync(topicFilters); + return session.SubscribeAsync(topicFilters); + } } public Task UnsubscribeAsync(string clientId, IList topicFilters) @@ -96,17 +109,23 @@ namespace MQTTnet.Server if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - if (!_sessions.TryGetValue(clientId, out var session)) + lock (_sessions) { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); + } - return session.UnsubscribeAsync(topicFilters); + return session.UnsubscribeAsync(topicFilters); + } } public void DeleteSession(string clientId) { - _sessions.TryRemove(clientId, out _); + lock (_sessions) + { + _sessions.Remove(clientId); + } _logger.Verbose("Session for client '{0}' deleted.", clientId); } @@ -143,7 +162,7 @@ namespace MQTTnet.Server _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult(); } - foreach (var clientSession in _sessions.Values) + foreach (var clientSession in GetSessions()) { clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage); } @@ -158,6 +177,14 @@ namespace MQTTnet.Server } } + private List GetSessions() + { + lock (_sessions) + { + return _sessions.Values.ToList(); + } + } + private async Task RunSession(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) { var clientId = string.Empty; @@ -263,7 +290,7 @@ namespace MQTTnet.Server { if (connectPacket.CleanSession) { - _sessions.TryRemove(connectPacket.ClientId, out _); + _sessions.Remove(connectPacket.ClientId); clientSession.Stop(MqttClientDisconnectType.Clean); clientSession.Dispose(); From 0322660561ed93060385d83d9fa1f2d7e533aeb6 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Wed, 20 Jun 2018 20:07:25 +0200 Subject: [PATCH 18/18] Refactor async/await and ConcurrentDictionary usage. --- README.md | 8 +- .../MQTTnet.Extensions.Rpc/MqttRpcClient.cs | 4 +- Source/MQTTnet/Client/MqttClient.cs | 13 +- .../MQTTnet/Internal/AsyncAutoResetEvent.cs | 33 +++-- Source/MQTTnet/Internal/AsyncLock.cs | 2 +- Source/MQTTnet/Serializer/MqttPacketReader.cs | 9 +- Source/MQTTnet/Server/IMqttClientSession.cs | 23 +++ .../Server/MqttClientKeepAliveMonitor.cs | 26 ++-- .../Server/MqttClientPendingPacketsQueue.cs | 58 +++++--- Source/MQTTnet/Server/MqttClientSession.cs | 136 +++++++++++------- .../Server/MqttClientSessionsManager.cs | 14 +- .../Server/MqttClientSubscriptionsManager.cs | 45 +++--- .../Server/MqttEnqueuedApplicationMessage.cs | 10 +- .../Server/MqttRetainedMessagesManager.cs | 65 +++++---- Source/MQTTnet/Server/MqttServer.cs | 10 +- .../MqttKeepAliveMonitorTests.cs | 82 ++++++++--- .../MqttSubscriptionsManagerTests.cs | 42 +----- .../PerformanceTest.cs | 43 +++++- Tests/MQTTnet.TestApp.NetCore/Program.cs | 14 +- Tests/MQTTnet.TestApp.NetCore/ServerTest.cs | 11 +- 20 files changed, 406 insertions(+), 242 deletions(-) create mode 100644 Source/MQTTnet/Server/IMqttClientSession.cs diff --git a/README.md b/README.md index 52735ca..d50f51c 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov * TLS 1.2 support for client and server (but not UWP servers) * Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS) * Lightweight (only the low level implementation of MQTT, no overhead) -* Performance optimized (processing ~60.000 messages / second)* +* Performance optimized (processing ~70.000 messages / second)* * Interfaces included for mocking and testing * Access to internal trace messages * Unit tested (~90 tests) @@ -50,14 +50,15 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov * .NET Standard 1.3+ * .NET Core 1.1+ * .NET Core App 1.1+ -* Universal Windows Platform (UWP) 10.0.10240+ (x86, x64, ARM, AnyCPU, Windows 10 IoT Core) * .NET Framework 4.5.2+ (x86, x64, AnyCPU) * Mono 5.2+ +* Universal Windows Platform (UWP) 10.0.10240+ (x86, x64, ARM, AnyCPU, Windows 10 IoT Core) * Xamarin.Android 7.5+ * Xamarin.iOS 10.14+ ## Supported MQTT versions +* 5.0.0 (planned) * 3.1.1 * 3.1.0 @@ -79,8 +80,7 @@ This library is used in the following projects: * MQTT Client Rx (Wrapper for Reactive Extensions, ) * MQTT Tester (MQTT client test app for [Android](https://play.google.com/store/apps/details?id=com.liveowl.mqtttester) and [iOS](https://itunes.apple.com/us/app/mqtt-tester/id1278621826?mt=8)) -* Wirehome (Open Source Home Automation system for .NET, ) - +* HA4IoT (Open Source Home Automation system for .NET, ) If you use this library and want to see your project here please let me know. diff --git a/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs b/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs index 717f74e..ebf66ec 100644 --- a/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs +++ b/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs @@ -82,11 +82,11 @@ namespace MQTTnet.Extensions.Rpc timeoutCts.Cancel(false); return result; } - catch (TaskCanceledException taskCanceledException) + catch (OperationCanceledException exception) { if (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) { - throw new MqttCommunicationTimedOutException(taskCanceledException); + throw new MqttCommunicationTimedOutException(exception); } else { diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index e9362f6..ccedb5a 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -271,21 +271,16 @@ namespace MQTTnet.Client private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - if (cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } + cancellationToken.ThrowIfCancellationRequested(); _sendTracker.Restart(); + return _adapter.SendPacketAsync(packet, cancellationToken); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket { - if (cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } + cancellationToken.ThrowIfCancellationRequested(); _sendTracker.Restart(); @@ -524,7 +519,7 @@ namespace MQTTnet.Client { await task.ConfigureAwait(false); } - catch (TaskCanceledException) + catch (OperationCanceledException) { } } diff --git a/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs b/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs index 132e7e0..a90b98e 100644 --- a/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs +++ b/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs @@ -11,8 +11,10 @@ namespace MQTTnet.Internal private readonly LinkedList> _waiters = new LinkedList>(); private bool _isSignaled; - public AsyncAutoResetEvent() : this(false) - { } + public AsyncAutoResetEvent() + : this(false) + { + } public AsyncAutoResetEvent(bool signaled) { @@ -58,27 +60,24 @@ namespace MQTTnet.Internal } var winner = await Task.WhenAny(tcs.Task, Task.Delay(timeout, cancellationToken)).ConfigureAwait(false); - if (winner == tcs.Task) + var taskWasSignaled = winner == tcs.Task; + if (taskWasSignaled) { - // The task was signaled. return true; } - else + + // We timed-out; remove our reference to the task. + // This is an O(n) operation since waiters is a LinkedList. + lock (_waiters) { - // We timed-out; remove our reference to the task. - // This is an O(n) operation since waiters is a LinkedList. - lock (_waiters) + _waiters.Remove(tcs); + + if (winner.Status == TaskStatus.Canceled) { - _waiters.Remove(tcs); - if (winner.Status == TaskStatus.Canceled) - { - throw new OperationCanceledException(cancellationToken); - } - else - { - throw new TimeoutException(); - } + throw new OperationCanceledException(cancellationToken); } + + throw new TimeoutException(); } } diff --git a/Source/MQTTnet/Internal/AsyncLock.cs b/Source/MQTTnet/Internal/AsyncLock.cs index 87878fa..87571c2 100644 --- a/Source/MQTTnet/Internal/AsyncLock.cs +++ b/Source/MQTTnet/Internal/AsyncLock.cs @@ -17,7 +17,7 @@ namespace MQTTnet.Internal public Task LockAsync(CancellationToken cancellationToken) { - Task wait = _semaphore.WaitAsync(cancellationToken); + var wait = _semaphore.WaitAsync(cancellationToken); return wait.IsCompleted ? _releaser : wait.ContinueWith((_, state) => (IDisposable)state, diff --git a/Source/MQTTnet/Serializer/MqttPacketReader.cs b/Source/MQTTnet/Serializer/MqttPacketReader.cs index 5bbed52..826747c 100644 --- a/Source/MQTTnet/Serializer/MqttPacketReader.cs +++ b/Source/MQTTnet/Serializer/MqttPacketReader.cs @@ -29,11 +29,7 @@ namespace MQTTnet.Serializer var bytesRead = await channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) { - if (cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } - + cancellationToken.ThrowIfCancellationRequested(); ExceptionHelper.ThrowGracefulSocketClose(); } @@ -59,6 +55,8 @@ namespace MQTTnet.Serializer while ((encodedByte & 128) != 0) { + cancellationToken.ThrowIfCancellationRequested(); + // Here the async/await pattern is not used becuase the overhead of context switches // is too big for reading 1 byte in a row. We expect that the remaining data was sent // directly after the initial bytes. If the client disconnects just in this moment we @@ -83,6 +81,7 @@ namespace MQTTnet.Serializer var readCount = channel.ReadAsync(buffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); if (readCount <= 0) { + cancellationToken.ThrowIfCancellationRequested(); ExceptionHelper.ThrowGracefulSocketClose(); } diff --git a/Source/MQTTnet/Server/IMqttClientSession.cs b/Source/MQTTnet/Server/IMqttClientSession.cs new file mode 100644 index 0000000..a94ad18 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttClientSession.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using MQTTnet.Adapter; +using MQTTnet.Packets; + +namespace MQTTnet.Server +{ + public interface IMqttClientSession : IDisposable + { + string ClientId { get; } + void FillStatus(MqttClientSessionStatus status); + + void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket); + void ClearPendingApplicationMessages(); + + Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter); + void Stop(MqttClientDisconnectType disconnectType); + + Task SubscribeAsync(IList topicFilters); + Task UnsubscribeAsync(IList topicFilters); + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs index a413088..b362861 100644 --- a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs +++ b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs @@ -12,19 +12,17 @@ namespace MQTTnet.Server private readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch(); private readonly Stopwatch _lastNonKeepAlivePacketReceivedTracker = new Stopwatch(); + private readonly IMqttClientSession _clientSession; private readonly IMqttNetChildLogger _logger; - private readonly string _clientId; - private readonly Action _callback; - + private bool _isPaused; - private Task _workerTask; - - public MqttClientKeepAliveMonitor(string clientId, Action callback, IMqttNetChildLogger logger) + + public MqttClientKeepAliveMonitor(IMqttClientSession clientSession, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); - _clientId = clientId; - _callback = callback; + _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); + _logger = logger.CreateChildLogger(nameof(MqttClientKeepAliveMonitor)); } @@ -39,7 +37,7 @@ namespace MQTTnet.Server return; } - _workerTask = Task.Run(() => RunAsync(keepAlivePeriod, cancellationToken), cancellationToken); + Task.Run(() => RunAsync(keepAlivePeriod, cancellationToken), cancellationToken); } public void Pause() @@ -74,9 +72,9 @@ namespace MQTTnet.Server // Values described here: [MQTT-3.1.2-24]. if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds > keepAlivePeriod * 1.5D) { - _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientId); - _callback(); - + _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId); + _clientSession.Stop(MqttClientDisconnectType.NotClean); + return; } @@ -88,11 +86,11 @@ namespace MQTTnet.Server } catch (Exception exception) { - _logger.Error(exception, "Client '{0}': Unhandled exception while checking keep alive timeouts.", _clientId); + _logger.Error(exception, "Client '{0}': Unhandled exception while checking keep alive timeouts.", _clientSession.ClientId); } finally { - _logger.Verbose("Client {0}: Stopped checking keep alive timeout.", _clientId); + _logger.Verbose("Client {0}: Stopped checking keep alive timeout.", _clientSession.ClientId); } } } diff --git a/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs index 7a55c0b..503d992 100644 --- a/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs +++ b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs @@ -1,5 +1,5 @@ using System; -using System.Collections.Concurrent; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -13,13 +13,13 @@ namespace MQTTnet.Server { public class MqttClientPendingPacketsQueue : IDisposable { + private readonly Queue _queue = new Queue(); private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); + private readonly IMqttServerOptions _options; private readonly MqttClientSession _clientSession; private readonly IMqttNetChildLogger _logger; - private ConcurrentQueue _queue = new ConcurrentQueue(); - public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); @@ -29,7 +29,16 @@ namespace MQTTnet.Server _logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue)); } - public int Count => _queue.Count; + public int Count + { + get + { + lock (_queue) + { + return _queue.Count; + } + } + } public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken) { @@ -42,25 +51,29 @@ namespace MQTTnet.Server Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken); } - + public void Enqueue(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); - if (_queue.Count >= _options.MaxPendingMessagesPerClient) + lock (_queue) { - if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage) + if (_queue.Count >= _options.MaxPendingMessagesPerClient) { - return; - } + if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage) + { + return; + } - if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage) - { - _queue.TryDequeue(out _); + if (_options.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage) + { + _queue.Dequeue(); + } } + + _queue.Enqueue(packet); } - _queue.Enqueue(packet); _queueAutoResetEvent.Set(); _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); @@ -68,13 +81,14 @@ namespace MQTTnet.Server public void Clear() { - var newQueue = new ConcurrentQueue(); - Interlocked.Exchange(ref _queue, newQueue); + lock (_queue) + { + _queue.Clear(); + } } public void Dispose() { - } private async Task SendQueuedPacketsAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken) @@ -100,13 +114,17 @@ namespace MQTTnet.Server MqttBasePacket packet = null; try { - if (_queue.IsEmpty) + lock (_queue) { - await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false); + if (_queue.Count > 0) + { + packet = _queue.Dequeue(); + } } - if (!_queue.TryDequeue(out packet)) + if (packet == null) { + await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false); return; } @@ -115,7 +133,7 @@ namespace MQTTnet.Server return; } - await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); + adapter.SendPacketAsync(packet, cancellationToken).GetAwaiter().GetResult(); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); } diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 460bea4..429fdb4 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -12,7 +12,7 @@ using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientSession : IDisposable + public class MqttClientSession : IMqttClientSession { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); @@ -47,7 +47,7 @@ namespace MQTTnet.Server _logger = logger.CreateChildLogger(nameof(MqttClientSession)); - _keepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, () => Stop(MqttClientDisconnectType.NotClean), _logger); + _keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server); _pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger); } @@ -89,7 +89,7 @@ namespace MQTTnet.Server if (packet != null) { _keepAliveMonitor.PacketReceived(packet); - await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false); + ProcessReceivedPacket(adapter, packet, _cancellationTokenSource.Token); } } } @@ -102,7 +102,7 @@ namespace MQTTnet.Server { if (exception is MqttCommunicationClosedGracefullyException) { - _logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId); ; + _logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId); } else { @@ -113,7 +113,7 @@ namespace MQTTnet.Server { _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); } - + Stop(MqttClientDisconnectType.NotClean); } finally @@ -123,7 +123,7 @@ namespace MQTTnet.Server _adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; _adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; } - + _adapter = null; _cancellationTokenSource?.Dispose(); @@ -149,7 +149,7 @@ namespace MQTTnet.Server if (_willMessage != null && !_wasCleanDisconnect) { - _sessionsManager.EnqueueApplicationMessage(this, _willMessage); + _sessionsManager.EnqueueApplicationMessage(this, _willMessage.ToPublishPacket()); } _willMessage = null; @@ -160,18 +160,24 @@ namespace MQTTnet.Server } } - public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) { - if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); - var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(applicationMessage); + var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(publishPacket.Topic, publishPacket.QualityOfServiceLevel); if (!checkSubscriptionsResult.IsSubscribed) { return; } - var publishPacket = applicationMessage.ToPublishPacket(); - publishPacket.QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel; + publishPacket = new MqttPublishPacket + { + Topic = publishPacket.Topic, + Payload = publishPacket.Payload, + QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel, + Retain = false, + Dup = false + }; if (publishPacket.QualityOfServiceLevel > 0) { @@ -184,15 +190,19 @@ namespace MQTTnet.Server senderClientSession?.ClientId, ClientId, publishPacket.ToApplicationMessage()); - + _options.ClientMessageQueueInterceptor?.Invoke(context); if (!context.AcceptEnqueue || context.ApplicationMessage == null) { return; } + + publishPacket.Topic = context.ApplicationMessage.Topic; + publishPacket.Payload = context.ApplicationMessage.Payload; + publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; } - + _pendingPacketsQueue.Enqueue(publishPacket); } @@ -233,21 +243,29 @@ namespace MQTTnet.Server _cancellationTokenSource?.Dispose(); } - private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) + private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) { if (packet is MqttPublishPacket publishPacket) { - return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken); + HandleIncomingPublishPacket(adapter, publishPacket, cancellationToken); + return; } if (packet is MqttPingReqPacket) { - return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); + adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken).GetAwaiter().GetResult(); + return; } if (packet is MqttPubRelPacket pubRelPacket) { - return HandleIncomingPubRelPacketAsync(adapter, pubRelPacket, cancellationToken); + var responsePacket = new MqttPubCompPacket + { + PacketIdentifier = pubRelPacket.PacketIdentifier + }; + + adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult(); + return; } if (packet is MqttPubRecPacket pubRecPacket) @@ -257,40 +275,41 @@ namespace MQTTnet.Server PacketIdentifier = pubRecPacket.PacketIdentifier }; - return adapter.SendPacketAsync(responsePacket, cancellationToken); + adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult(); + return; } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) { - // Discard message. - return Task.FromResult(0); + return; } if (packet is MqttSubscribePacket subscribePacket) { - return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken); + HandleIncomingSubscribePacket(adapter, subscribePacket, cancellationToken); + return; } if (packet is MqttUnsubscribePacket unsubscribePacket) { - return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken); + HandleIncomingUnsubscribePacket(adapter, unsubscribePacket, cancellationToken); + return; } if (packet is MqttDisconnectPacket) { Stop(MqttClientDisconnectType.Clean); - return Task.FromResult(0); + return; } if (packet is MqttConnectPacket) { Stop(MqttClientDisconnectType.NotClean); - return Task.FromResult(0); + return; } _logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); Stop(MqttClientDisconnectType.NotClean); - return Task.FromResult(0); } private void EnqueueSubscribedRetainedMessages(ICollection topicFilters) @@ -298,14 +317,14 @@ namespace MQTTnet.Server var retainedMessages = _retainedMessagesManager.GetSubscribedMessages(topicFilters); foreach (var applicationMessage in retainedMessages) { - EnqueueApplicationMessage(null, applicationMessage); + EnqueueApplicationMessage(null, applicationMessage.ToPublishPacket()); } } - private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) + private void HandleIncomingSubscribePacket(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); - await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); + adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).GetAwaiter().GetResult(); if (subscribeResult.CloseConnection) { @@ -316,30 +335,30 @@ namespace MQTTnet.Server EnqueueSubscribedRetainedMessages(subscribePacket.TopicFilters); } - private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) + private void HandleIncomingUnsubscribePacket(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); - return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); + adapter.SendPacketAsync(unsubscribeResult, cancellationToken).GetAwaiter().GetResult(); } - private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private void HandleIncomingPublishPacket(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - var applicationMessage = publishPacket.ToApplicationMessage(); - - switch (applicationMessage.QualityOfServiceLevel) + switch (publishPacket.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: { - _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); - return Task.FromResult(0); + HandleIncomingPublishPacketWithQoS0(publishPacket); + break; } case MqttQualityOfServiceLevel.AtLeastOnce: { - return HandleIncomingPublishPacketWithQoS1(adapter, applicationMessage, publishPacket, cancellationToken); + HandleIncomingPublishPacketWithQoS1(adapter, publishPacket, cancellationToken); + break; } case MqttQualityOfServiceLevel.ExactlyOnce: { - return HandleIncomingPublishPacketWithQoS2(adapter, applicationMessage, publishPacket, cancellationToken); + HandleIncomingPublishPacketWithQoS2(adapter, publishPacket, cancellationToken); + break; } default: { @@ -348,27 +367,40 @@ namespace MQTTnet.Server } } - private Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket) { - _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); - - var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - return adapter.SendPacketAsync(response, cancellationToken); + _sessionsManager.EnqueueApplicationMessage(this, publishPacket); } - private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private void HandleIncomingPublishPacketWithQoS1( + IMqttChannelAdapter adapter, + MqttPublishPacket publishPacket, + CancellationToken cancellationToken) { - // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) - _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, publishPacket); + + var response = new MqttPubAckPacket + { + PacketIdentifier = publishPacket.PacketIdentifier + }; - var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - return adapter.SendPacketAsync(response, cancellationToken); + adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult(); } - private static Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) + private void HandleIncomingPublishPacketWithQoS2( + IMqttChannelAdapter adapter, + MqttPublishPacket publishPacket, + CancellationToken cancellationToken) { - var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }; - return adapter.SendPacketAsync(response, cancellationToken); + // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) + _sessionsManager.EnqueueApplicationMessage(this, publishPacket); + + var response = new MqttPubRecPacket + { + PacketIdentifier = publishPacket.PacketIdentifier + }; + + adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult(); } private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 5937265..8188da8 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -41,7 +42,7 @@ namespace MQTTnet.Server Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); } - public Task StopAsync() + public void Stop() { foreach (var session in _sessions) { @@ -49,7 +50,6 @@ namespace MQTTnet.Server } _sessions.Clear(); - return Task.FromResult(0); } public Task StartSession(IMqttChannelAdapter clientAdapter) @@ -71,11 +71,11 @@ namespace MQTTnet.Server return Task.FromResult((IList)result); } - public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) { - if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); - _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); + _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, publishPacket), _cancellationToken); } public Task SubscribeAsync(string clientId, IList topicFilters) @@ -118,7 +118,7 @@ namespace MQTTnet.Server { var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken); var sender = enqueuedApplicationMessage.Sender; - var applicationMessage = enqueuedApplicationMessage.ApplicationMessage; + var applicationMessage = enqueuedApplicationMessage.PublishPacket.ToApplicationMessage(); var interceptorContext = InterceptApplicationMessage(sender, applicationMessage); if (interceptorContext != null) @@ -145,7 +145,7 @@ namespace MQTTnet.Server foreach (var clientSession in _sessions.Values) { - clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage); + clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage.ToPublishPacket()); } } catch (OperationCanceledException) diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index 5f067ff..83ac033 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using MQTTnet.Packets; @@ -9,7 +8,7 @@ namespace MQTTnet.Server { public class MqttClientSubscriptionsManager { - private readonly ConcurrentDictionary _subscriptions = new ConcurrentDictionary(); + private readonly Dictionary _subscriptions = new Dictionary(); private readonly IMqttServerOptions _options; private readonly MqttServer _server; private readonly string _clientId; @@ -54,7 +53,11 @@ namespace MQTTnet.Server if (interceptorContext.AcceptSubscription) { - _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + lock (_subscriptions) + { + _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + } + _server.OnClientSubscribedTopic(_clientId, topicFilter); } } @@ -66,10 +69,14 @@ namespace MQTTnet.Server { if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); - foreach (var topicFilter in unsubscribePacket.TopicFilters) + lock (_subscriptions) { - _subscriptions.TryRemove(topicFilter, out _); - _server.OnClientUnsubscribedTopic(_clientId, topicFilter); + foreach (var topicFilter in unsubscribePacket.TopicFilters) + { + _subscriptions.Remove(topicFilter); + + _server.OnClientUnsubscribedTopic(_clientId, topicFilter); + } } return new MqttUnsubAckPacket @@ -78,19 +85,21 @@ namespace MQTTnet.Server }; } - public CheckSubscriptionsResult CheckSubscriptions(MqttApplicationMessage applicationMessage) + public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel) { - if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - var qosLevels = new HashSet(); - foreach (var subscription in _subscriptions) + + lock (_subscriptions) { - if (!MqttTopicFilterComparer.IsMatch(applicationMessage.Topic, subscription.Key)) + foreach (var subscription in _subscriptions) { - continue; - } + if (!MqttTopicFilterComparer.IsMatch(topic, subscription.Key)) + { + continue; + } - qosLevels.Add(subscription.Value); + qosLevels.Add(subscription.Value); + } } if (qosLevels.Count == 0) @@ -101,7 +110,7 @@ namespace MQTTnet.Server }; } - return CreateSubscriptionResult(applicationMessage, qosLevels); + return CreateSubscriptionResult(qosLevel, qosLevels); } private static MqttSubscribeReturnCode ConvertToMaximumQoS(MqttQualityOfServiceLevel qualityOfServiceLevel) @@ -122,12 +131,12 @@ namespace MQTTnet.Server return interceptorContext; } - private static CheckSubscriptionsResult CreateSubscriptionResult(MqttApplicationMessage applicationMessage, HashSet subscribedQoSLevels) + private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet subscribedQoSLevels) { MqttQualityOfServiceLevel effectiveQoS; - if (subscribedQoSLevels.Contains(applicationMessage.QualityOfServiceLevel)) + if (subscribedQoSLevels.Contains(qosLevel)) { - effectiveQoS = applicationMessage.QualityOfServiceLevel; + effectiveQoS = qosLevel; } else if (subscribedQoSLevels.Count == 1) { diff --git a/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs index 20ff2fe..37591d8 100644 --- a/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs +++ b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs @@ -1,15 +1,17 @@ -namespace MQTTnet.Server +using MQTTnet.Packets; + +namespace MQTTnet.Server { public class MqttEnqueuedApplicationMessage { - public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) + public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttPublishPacket publishPacket) { Sender = sender; - ApplicationMessage = applicationMessage; + PublishPacket = publishPacket; } public MqttClientSession Sender { get; } - public MqttApplicationMessage ApplicationMessage { get; } + public MqttPublishPacket PublishPacket { get; } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs index 9fe454c..86c321f 100644 --- a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -9,7 +8,8 @@ namespace MQTTnet.Server { public class MqttRetainedMessagesManager { - private readonly ConcurrentDictionary _messages = new ConcurrentDictionary(); + private readonly Dictionary _messages = new Dictionary(); + private readonly IMqttNetChildLogger _logger; private readonly IMqttServerOptions _options; @@ -31,10 +31,13 @@ namespace MQTTnet.Server { var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); - _messages.Clear(); - foreach (var retainedMessage in retainedMessages) + lock (_messages) { - _messages[retainedMessage.Topic] = retainedMessage; + _messages.Clear(); + foreach (var retainedMessage in retainedMessages) + { + _messages[retainedMessage.Topic] = retainedMessage; + } } } catch (Exception exception) @@ -61,17 +64,20 @@ namespace MQTTnet.Server { var retainedMessages = new List(); - foreach (var retainedMessage in _messages.Values) + lock (_messages) { - foreach (var topicFilter in topicFilters) + foreach (var retainedMessage in _messages.Values) { - if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic)) + foreach (var topicFilter in topicFilters) { - continue; - } + if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic)) + { + continue; + } - retainedMessages.Add(retainedMessage); - break; + retainedMessages.Add(retainedMessage); + break; + } } } @@ -82,28 +88,31 @@ namespace MQTTnet.Server { var saveIsRequired = false; - if (applicationMessage.Payload?.Length == 0) - { - saveIsRequired = _messages.TryRemove(applicationMessage.Topic, out _); - _logger.Info("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic); - } - else + lock (_messages) { - if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage)) + if (applicationMessage.Payload?.Length == 0) { - _messages[applicationMessage.Topic] = applicationMessage; - saveIsRequired = true; + saveIsRequired = _messages.Remove(applicationMessage.Topic); + _logger.Info("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic); } else { - if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0])) + if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage)) { _messages[applicationMessage.Topic] = applicationMessage; saveIsRequired = true; } - } + else + { + if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0])) + { + _messages[applicationMessage.Topic] = applicationMessage; + saveIsRequired = true; + } + } - _logger.Info("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic); + _logger.Info("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic); + } } if (!saveIsRequired) @@ -113,7 +122,13 @@ namespace MQTTnet.Server if (saveIsRequired && _options.Storage != null) { - await _options.Storage.SaveRetainedMessagesAsync(_messages.Values.ToList()).ConfigureAwait(false); + List messages; + lock (_messages) + { + messages = _messages.Values.ToList(); + } + + await _options.Storage.SaveRetainedMessagesAsync(messages).ConfigureAwait(false); } } } diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index 38e631e..69850d5 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; +using MQTTnet.Internal; namespace MQTTnet.Server { @@ -65,7 +66,7 @@ namespace MQTTnet.Server if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started."); - _clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage); + _clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage.ToPublishPacket()); return Task.FromResult(0); } @@ -104,22 +105,23 @@ namespace MQTTnet.Server } _cancellationTokenSource.Cancel(false); - _cancellationTokenSource.Dispose(); - + foreach (var adapter in _adapters) { adapter.ClientAccepted -= OnClientAccepted; await adapter.StopAsync().ConfigureAwait(false); } - await _clientSessionsManager.StopAsync().ConfigureAwait(false); + _clientSessionsManager.Stop(); _logger.Info("Stopped."); Stopped?.Invoke(this, EventArgs.Empty); } finally { + _cancellationTokenSource?.Dispose(); _cancellationTokenSource = null; + _retainedMessagesManager = null; _clientSessionsManager = null; } diff --git a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs index 9563a41..267b7bc 100644 --- a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs @@ -1,5 +1,9 @@ -using System.Threading; +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Packets; using MQTTnet.Server; @@ -12,39 +16,31 @@ namespace MQTTnet.Core.Tests [TestMethod] public void KeepAlive_Timeout() { - var timeoutCalledCount = 0; + var clientSession = new TestClientSession(); + var monitor = new MqttClientKeepAliveMonitor(clientSession, new MqttNetLogger().CreateChildLogger()); - var monitor = new MqttClientKeepAliveMonitor(string.Empty, delegate - { - timeoutCalledCount++; - }, new MqttNetLogger().CreateChildLogger("")); - - Assert.AreEqual(0, timeoutCalledCount); + Assert.AreEqual(0, clientSession.StopCalledCount); monitor.Start(1, CancellationToken.None); - Assert.AreEqual(0, timeoutCalledCount); + Assert.AreEqual(0, clientSession.StopCalledCount); Thread.Sleep(2000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. - Assert.AreEqual(1, timeoutCalledCount); + Assert.AreEqual(1, clientSession.StopCalledCount); } [TestMethod] public void KeepAlive_NoTimeout() { - var timeoutCalledCount = 0; - - var monitor = new MqttClientKeepAliveMonitor(string.Empty, delegate - { - timeoutCalledCount++; - }, new MqttNetLogger().CreateChildLogger("")); + var clientSession = new TestClientSession(); + var monitor = new MqttClientKeepAliveMonitor(clientSession, new MqttNetLogger().CreateChildLogger()); - Assert.AreEqual(0, timeoutCalledCount); + Assert.AreEqual(0, clientSession.StopCalledCount); monitor.Start(1, CancellationToken.None); - Assert.AreEqual(0, timeoutCalledCount); + Assert.AreEqual(0, clientSession.StopCalledCount); // Simulate traffic. Thread.Sleep(1000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. @@ -53,11 +49,57 @@ namespace MQTTnet.Core.Tests monitor.PacketReceived(new MqttPublishPacket()); Thread.Sleep(1000); - Assert.AreEqual(0, timeoutCalledCount); + Assert.AreEqual(0, clientSession.StopCalledCount); Thread.Sleep(2000); - Assert.AreEqual(1, timeoutCalledCount); + Assert.AreEqual(1, clientSession.StopCalledCount); + } + + private class TestClientSession : IMqttClientSession + { + public string ClientId { get; } + + public int StopCalledCount { get; set; } + + public void FillStatus(MqttClientSessionStatus status) + { + throw new NotSupportedException(); + } + + public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) + { + throw new NotSupportedException(); + } + + public void ClearPendingApplicationMessages() + { + throw new NotSupportedException(); + } + + public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) + { + throw new NotSupportedException(); + } + + public void Stop(MqttClientDisconnectType disconnectType) + { + StopCalledCount++; + } + + public Task SubscribeAsync(IList topicFilters) + { + throw new NotSupportedException(); + } + + public Task UnsubscribeAsync(IList topicFilters) + { + throw new NotSupportedException(); + } + + public void Dispose() + { + } } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index aa84dd3..268e5fe 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -20,13 +20,7 @@ namespace MQTTnet.Core.Tests sm.Subscribe(sp); - var pp = new MqttApplicationMessage - { - Topic = "A/B/C", - QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce - }; - - var result = sm.CheckSubscriptions(pp); + var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); Assert.IsTrue(result.IsSubscribed); Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtMostOnce); } @@ -41,13 +35,7 @@ namespace MQTTnet.Core.Tests sm.Subscribe(sp); - var pp = new MqttApplicationMessage - { - Topic = "A/B/C", - QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce - }; - - var result = sm.CheckSubscriptions(pp); + var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtMostOnce); } @@ -63,13 +51,7 @@ namespace MQTTnet.Core.Tests sm.Subscribe(sp); - var pp = new MqttApplicationMessage - { - Topic = "A/B/C", - QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce - }; - - var result = sm.CheckSubscriptions(pp); + var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); Assert.AreEqual(result.QualityOfServiceLevel, MqttQualityOfServiceLevel.AtLeastOnce); } @@ -84,13 +66,7 @@ namespace MQTTnet.Core.Tests sm.Subscribe(sp); - var pp = new MqttApplicationMessage - { - Topic = "A/B/X", - QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce - }; - - Assert.IsFalse(sm.CheckSubscriptions(pp).IsSubscribed); + Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } [TestMethod] @@ -103,19 +79,13 @@ namespace MQTTnet.Core.Tests sm.Subscribe(sp); - var pp = new MqttApplicationMessage - { - Topic = "A/B/C", - QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce - }; - - Assert.IsTrue(sm.CheckSubscriptions(pp).IsSubscribed); + Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); var up = new MqttUnsubscribePacket(); up.TopicFilters.Add("A/B/C"); sm.Unsubscribe(up); - Assert.IsFalse(sm.CheckSubscriptions(pp).IsSubscribed); + Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } } } diff --git a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs index 752feb5..3f4a1d5 100644 --- a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs @@ -12,7 +12,48 @@ namespace MQTTnet.TestApp.NetCore { public static class PerformanceTest { - public static void Run() + public static void RunClientOnly() + { + try + { + var options = new MqttClientOptions + { + ChannelOptions = new MqttClientTcpOptions + { + Server = "127.0.0.1" + }, + CleanSession = true + }; + + var client = new MqttFactory().CreateMqttClient(); + client.ConnectAsync(options).GetAwaiter().GetResult(); + + var message = CreateMessage(); + var stopwatch = new Stopwatch(); + + for (var i = 0; i < 10; i++) + { + var sentMessagesCount = 0; + + stopwatch.Restart(); + while (stopwatch.ElapsedMilliseconds < 1000) + { + client.PublishAsync(message).GetAwaiter().GetResult(); + sentMessagesCount++; + } + + Console.WriteLine($"Sending {sentMessagesCount} messages per second. #" + (i + 1)); + + GC.Collect(); + } + } + catch (Exception exception) + { + Console.WriteLine(exception); + } + } + + public static void RunClientAndServer() { try { diff --git a/Tests/MQTTnet.TestApp.NetCore/Program.cs b/Tests/MQTTnet.TestApp.NetCore/Program.cs index 25302c7..908844f 100644 --- a/Tests/MQTTnet.TestApp.NetCore/Program.cs +++ b/Tests/MQTTnet.TestApp.NetCore/Program.cs @@ -22,6 +22,8 @@ namespace MQTTnet.TestApp.NetCore Console.WriteLine("5 = Start public broker test"); Console.WriteLine("6 = Start server & client"); Console.WriteLine("7 = Client flow test"); + Console.WriteLine("8 = Start performance test (client only)"); + Console.WriteLine("9 = Start server (no trace)"); var pressedKey = Console.ReadKey(true); if (pressedKey.KeyChar == '1') @@ -34,7 +36,7 @@ namespace MQTTnet.TestApp.NetCore } else if (pressedKey.KeyChar == '3') { - PerformanceTest.Run(); + PerformanceTest.RunClientAndServer(); return; } else if (pressedKey.KeyChar == '4') @@ -53,6 +55,16 @@ namespace MQTTnet.TestApp.NetCore { Task.Run(ClientFlowTest.RunAsync); } + else if (pressedKey.KeyChar == '8') + { + PerformanceTest.RunClientOnly(); + return; + } + else if (pressedKey.KeyChar == '9') + { + ServerTest.RunEmptyServer(); + return; + } Thread.Sleep(Timeout.Infinite); } diff --git a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs index 80a8e5a..12fa1aa 100644 --- a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs @@ -8,12 +8,19 @@ namespace MQTTnet.TestApp.NetCore { public static class ServerTest { + public static void RunEmptyServer() + { + var mqttServer = new MqttFactory().CreateMqttServer(); + mqttServer.StartAsync(new MqttServerOptions()).GetAwaiter().GetResult(); + + Console.WriteLine("Press any key to exit."); + Console.ReadLine(); + } + public static async Task RunAsync() { try { - MqttNetConsoleLogger.ForwardToConsole(); - var options = new MqttServerOptions { ConnectionValidator = p =>