diff --git a/Build/MQTTnet.Extensions.ManagedClient.nuspec b/Build/MQTTnet.Extensions.ManagedClient.nuspec index e681554..714ed9e 100644 --- a/Build/MQTTnet.Extensions.ManagedClient.nuspec +++ b/Build/MQTTnet.Extensions.ManagedClient.nuspec @@ -48,6 +48,8 @@ - + + + \ No newline at end of file diff --git a/Build/MQTTnet.Extensions.Rpc.nuspec b/Build/MQTTnet.Extensions.Rpc.nuspec index f8667c4..16a51c2 100644 --- a/Build/MQTTnet.Extensions.Rpc.nuspec +++ b/Build/MQTTnet.Extensions.Rpc.nuspec @@ -48,6 +48,8 @@ - + + + \ No newline at end of file diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 25c31fb..5d58437 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -69,6 +69,8 @@ - + + + \ No newline at end of file diff --git a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs index e026e7b..43e5c8c 100644 --- a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs @@ -20,7 +20,7 @@ namespace MQTTnet.Adapter Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken); - Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken); + Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken); Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); } diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 4fc681a..7db9a2a 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -40,52 +40,84 @@ namespace MQTTnet.Adapter public event EventHandler ReadingPacketStarted; public event EventHandler ReadingPacketCompleted; - public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + public async Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); - _logger.Verbose("Connecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => - Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); + try + { + _logger.Verbose("Connecting [Timeout={0}]", timeout); + + await Internal.TaskExtensions + .TimeoutAfterAsync(ct => _channel.ConnectAsync(ct), timeout, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } } - public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) + public async Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); - _logger.Verbose("Disconnecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => - Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, cancellationToken)); + try + { + _logger.Verbose("Disconnecting [Timeout={0}]", timeout); + + await Internal.TaskExtensions + .TimeoutAfterAsync(ct => _channel.DisconnectAsync(), timeout, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } } - public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken) + public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - return ExecuteAndWrapExceptionAsync(() => + try { - _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); + _logger.Verbose("TX >>> {0}", packet); var packetData = PacketSerializer.Serialize(packet); - return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( - packetData.Array, - packetData.Offset, - packetData.Count, - ct), timeout, cancellationToken); - }); + await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } } public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); - MqttBasePacket packet = null; - await ExecuteAndWrapExceptionAsync(async () => + try { ReceivedMqttPacket receivedMqttPacket; if (timeout > TimeSpan.Zero) { - receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); + receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfterAsync(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); } else { @@ -94,19 +126,30 @@ namespace MQTTnet.Adapter if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) { - return; + return null; } - packet = PacketSerializer.Deserialize(receivedMqttPacket); + var packet = PacketSerializer.Deserialize(receivedMqttPacket); if (packet == null) { throw new MqttProtocolViolationException("Received malformed packet."); } _logger.Verbose("RX <<< {0}", packet); - }).ConfigureAwait(false); + + return packet; + } + catch (Exception exception) + { + if (IsWrappedException(exception)) + { + throw; + } + + WrapException(exception); + } - return packet; + return null; } private async Task ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) @@ -138,7 +181,9 @@ namespace MQTTnet.Adapter chunkSize = bytesLeft; } - var readBytes = await channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken) .ConfigureAwait(false); + // async/await is not used to avoid the overhead of context switches. We assume that the reamining data + // has been sent from the sender directly after the initial bytes. + var readBytes = channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken).GetAwaiter().GetResult(); if (readBytes <= 0) { ExceptionHelper.ThrowGracefulSocketClose(); @@ -155,42 +200,6 @@ namespace MQTTnet.Adapter } } - private static async Task ExecuteAndWrapExceptionAsync(Func action) - { - try - { - await action().ConfigureAwait(false); - } - catch (Exception exception) - { - if (exception is TaskCanceledException || - exception is OperationCanceledException || - exception is MqttCommunicationTimedOutException || - exception is MqttCommunicationException) - { - throw; - } - - if (exception is IOException && exception.InnerException is SocketException socketException) - { - if (socketException.SocketErrorCode == SocketError.ConnectionAborted) - { - throw new OperationCanceledException(); - } - } - - if (exception is COMException comException) - { - if ((uint)comException.HResult == ErrorOperationAborted) - { - throw new OperationCanceledException(); - } - } - - throw new MqttCommunicationException(exception); - } - } - public void Dispose() { _isDisposed = true; @@ -205,5 +214,34 @@ namespace MQTTnet.Adapter throw new ObjectDisposedException(nameof(MqttChannelAdapter)); } } + + private static bool IsWrappedException(Exception exception) + { + return exception is TaskCanceledException || + exception is OperationCanceledException || + exception is MqttCommunicationTimedOutException || + exception is MqttCommunicationException; + } + + private static void WrapException(Exception exception) + { + if (exception is IOException && exception.InnerException is SocketException socketException) + { + if (socketException.SocketErrorCode == SocketError.ConnectionAborted) + { + throw new OperationCanceledException(); + } + } + + if (exception is COMException comException) + { + if ((uint)comException.HResult == ErrorOperationAborted) + { + throw new OperationCanceledException(); + } + } + + throw new MqttCommunicationException(exception); + } } } diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 79a0ecf..e9362f6 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -17,7 +17,7 @@ namespace MQTTnet.Client { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly Stopwatch _sendTracker = new Stopwatch(); - private readonly SemaphoreSlim _disconnectLock = new SemaphoreSlim(1, 1); + private readonly object _disconnectLock = new object(); private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly IMqttClientAdapterFactory _adapterFactory; @@ -215,7 +215,7 @@ namespace MQTTnet.Client private async Task DisconnectInternalAsync(Task sender, Exception exception) { - await InitiateDisconnectAsync().ConfigureAwait(false); + InitiateDisconnect(); var clientWasConnected = IsConnected; IsConnected = false; @@ -249,25 +249,23 @@ namespace MQTTnet.Client } } - private async Task InitiateDisconnectAsync() + private void InitiateDisconnect() { - await _disconnectLock.WaitAsync().ConfigureAwait(false); - try + lock (_disconnectLock) { - if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested) + try { - return; - } + if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested) + { + return; + } - _cancellationTokenSource.Cancel(false); - } - catch (Exception adapterException) - { - _logger.Warning(adapterException, "Error while initiating disconnect."); - } - finally - { - _disconnectLock.Release(); + _cancellationTokenSource.Cancel(false); + } + catch (Exception adapterException) + { + _logger.Warning(adapterException, "Error while initiating disconnect."); + } } } @@ -279,7 +277,7 @@ namespace MQTTnet.Client } _sendTracker.Restart(); - return _adapter.SendPacketAsync(_options.CommunicationTimeout, packet, cancellationToken); + return _adapter.SendPacketAsync(packet, cancellationToken); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket @@ -300,8 +298,8 @@ namespace MQTTnet.Client var packetAwaiter = _packetDispatcher.AddPacketAwaiter(identifier); try { - await _adapter.SendPacketAsync(_options.CommunicationTimeout, requestPacket, cancellationToken).ConfigureAwait(false); - var respone = await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); + await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); + var respone = await Internal.TaskExtensions.TimeoutAfterAsync(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); return (TResponsePacket)respone; } diff --git a/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs b/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs index 5258276..efbf08b 100644 --- a/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs +++ b/Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs @@ -10,6 +10,8 @@ return "net452"; #elif NET461 return "net461"; +#elif NET472 + return "net472"; #elif NETSTANDARD1_3 return "netstandard1.3"; #elif NETSTANDARD2_0 diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 20829be..c48a06f 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -1,4 +1,4 @@ -#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 +#if !WINDOWS_UWP using System; using System.Net.Security; using System.Net.Sockets; diff --git a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs index 8783e9c..eeefb34 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs @@ -1,4 +1,4 @@ -#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 +#if !WINDOWS_UWP using System; using System.Collections.Generic; using System.Net.Sockets; diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index fccd77f..b77d9b3 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -1,4 +1,4 @@ -#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 +#if !WINDOWS_UWP using System; using System.Net; using System.Net.Security; @@ -76,7 +76,8 @@ namespace MQTTnet.Implementations await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false); } - _logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {_addressFamily}'."); + var protocol = _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6"; + _logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {protocol}'."); var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); @@ -102,7 +103,7 @@ namespace MQTTnet.Implementations { _socket?.Dispose(); -#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 +#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 || NET472 _tlsCertificate?.Dispose(); #endif } diff --git a/Source/MQTTnet/Internal/TaskExtensions.cs b/Source/MQTTnet/Internal/TaskExtensions.cs index 288ac0b..1356d97 100644 --- a/Source/MQTTnet/Internal/TaskExtensions.cs +++ b/Source/MQTTnet/Internal/TaskExtensions.cs @@ -7,7 +7,7 @@ namespace MQTTnet.Internal { public static class TaskExtensions { - public static async Task TimeoutAfter(Func action, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task TimeoutAfterAsync(Func action, TimeSpan timeout, CancellationToken cancellationToken) { if (action == null) throw new ArgumentNullException(nameof(action)); @@ -31,7 +31,7 @@ namespace MQTTnet.Internal } } - public static async Task TimeoutAfter(Func> action, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task TimeoutAfterAsync(Func> action, TimeSpan timeout, CancellationToken cancellationToken) { if (action == null) throw new ArgumentNullException(nameof(action)); diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index 8364378..4afd09d 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -1,7 +1,7 @@  - netstandard1.3;netstandard2.0;net452;net461;uap10.0 + netstandard1.3;netstandard2.0;net452;uap10.0 netstandard1.3;netstandard2.0 MQTTnet MQTTnet @@ -62,7 +62,4 @@ - - - \ No newline at end of file diff --git a/Source/MQTTnet/Serializer/ByteReader.cs b/Source/MQTTnet/Serializer/ByteReader.cs deleted file mode 100644 index 461bba3..0000000 --- a/Source/MQTTnet/Serializer/ByteReader.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; - -namespace MQTTnet.Serializer -{ - public class ByteReader - { - private readonly int _source; - private int _index; - - public ByteReader(int source) - { - _source = source; - } - - public bool Read() - { - if (_index >= 8) - { - throw new InvalidOperationException("End of byte reached."); - } - - var result = ((1 << _index) & _source) > 0; - _index++; - return result; - } - - public int Read(int count) - { - if (_index + count > 8) - { - throw new InvalidOperationException("End of byte will be reached."); - } - - var result = 0; - for (var i = 0; i < count; i++) - { - if (((1 << _index) & _source) > 0) - { - result |= 1 << i; - } - - _index++; - } - - return result; - } - } -} diff --git a/Source/MQTTnet/Serializer/ByteWriter.cs b/Source/MQTTnet/Serializer/ByteWriter.cs deleted file mode 100644 index 9ae2156..0000000 --- a/Source/MQTTnet/Serializer/ByteWriter.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System; - -namespace MQTTnet.Serializer -{ - public class ByteWriter - { - private int _index; - private int _byte; - - public byte Value => (byte)_byte; - - public void Write(int @byte, int count) - { - for (var i = 0; i < count; i++) - { - var value = ((1 << i) & @byte) > 0; - Write(value); - } - } - - public void Write(bool bit) - { - if (_index >= 8) - { - throw new InvalidOperationException("End of the byte reached."); - } - - if (bit) - { - _byte |= 1 << _index; - } - - _index++; - } - } -} diff --git a/Source/MQTTnet/Serializer/MqttPacketReader.cs b/Source/MQTTnet/Serializer/MqttPacketReader.cs index 7ed918f..d50fae5 100644 --- a/Source/MQTTnet/Serializer/MqttPacketReader.cs +++ b/Source/MQTTnet/Serializer/MqttPacketReader.cs @@ -12,6 +12,8 @@ namespace MQTTnet.Serializer { // The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. // So in all cases at least 2 bytes must be read for a complete MQTT packet. + // async/await is used here because the next packet is received in a couple of minutes so the performance + // impact is acceptable according to a useless waiting thread. var buffer = new byte[2]; var totalBytesRead = 0; @@ -37,11 +39,11 @@ namespace MQTTnet.Serializer return new MqttFixedHeader(buffer[0], 0); } - var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken).ConfigureAwait(false); + var bodyLength = ReadBodyLength(channel, buffer[1], cancellationToken); return new MqttFixedHeader(buffer[0], bodyLength); } - private static async Task ReadBodyLengthAsync(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) + private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 128; @@ -50,7 +52,11 @@ namespace MQTTnet.Serializer while ((encodedByte & 128) != 0) { - encodedByte = await ReadByteAsync(channel, cancellationToken).ConfigureAwait(false); + // Here the async/await pattern is not used becuase the overhead of context switches + // is too big for reading 1 byte in a row. We expect that the remaining data was sent + // directly after the initial bytes. If the client disconnects just in this moment we + // will get an exception anyway. + encodedByte = ReadByteAsync(channel, cancellationToken).GetAwaiter().GetResult(); value += (byte)(encodedByte & 127) * multiplier; if (multiplier > 128 * 128 * 128) diff --git a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs index 2acbfbb..811acf7 100644 --- a/Source/MQTTnet/Serializer/MqttPacketSerializer.cs +++ b/Source/MQTTnet/Serializer/MqttPacketSerializer.cs @@ -2,7 +2,6 @@ using MQTTnet.Packets; using MQTTnet.Protocol; using System; -using System.IO; using System.Linq; using MQTTnet.Adapter; @@ -18,57 +17,46 @@ namespace MQTTnet.Serializer { if (packet == null) throw new ArgumentNullException(nameof(packet)); - using (var stream = new MemoryStream(128)) - { - // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) - stream.Seek(5, SeekOrigin.Begin); + var packetWriter = new MqttPacketWriter(); - var fixedHeader = SerializePacket(packet, stream); - var remainingLength = (int)stream.Length - 5; + // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) + packetWriter.Seek(5); - var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); + var fixedHeader = SerializePacket(packet, packetWriter); + var remainingLength = packetWriter.Length - 5; - var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; - var headerOffset = 5 - headerSize; + var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); - // Position cursor on correct offset on beginining of array (has leading 0x0) - stream.Seek(headerOffset, SeekOrigin.Begin); - stream.WriteByte(fixedHeader); - stream.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); + var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; + var headerOffset = 5 - headerSize; -#if NET461 || NET452 || NETSTANDARD2_0 - var buffer = stream.GetBuffer(); - return new ArraySegment(buffer, headerOffset, (int)stream.Length - headerOffset); -#else - if (stream.TryGetBuffer(out var segment)) - { - return new ArraySegment(segment.Array, headerOffset, segment.Count - headerOffset); - } + // Position cursor on correct offset on beginining of array (has leading 0x0) + packetWriter.Seek(headerOffset); + packetWriter.Write(fixedHeader); + packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); - var buffer = stream.ToArray(); - return new ArraySegment(buffer, headerOffset, buffer.Length - headerOffset); -#endif - } + var buffer = packetWriter.GetBuffer(); + return new ArraySegment(buffer, headerOffset, packetWriter.Length - headerOffset); } - private byte SerializePacket(MqttBasePacket packet, Stream stream) + private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) { switch (packet) { - case MqttConnectPacket connectPacket: return Serialize(connectPacket, stream); - case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, stream); + case MqttConnectPacket connectPacket: return Serialize(connectPacket, packetWriter); + case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, packetWriter); case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect); case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq); case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp); - case MqttPublishPacket publishPacket: return Serialize(publishPacket, stream); - case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, stream); - case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, stream); - case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, stream); - case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, stream); - case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, stream); - case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, stream); - case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, stream); - case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, stream); + case MqttPublishPacket publishPacket: return Serialize(publishPacket, packetWriter); + case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, packetWriter); + case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, packetWriter); + case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, packetWriter); + case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, packetWriter); + case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, packetWriter); + case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, packetWriter); + case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, packetWriter); + case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, packetWriter); default: throw new MqttProtocolViolationException("Packet type invalid."); } } @@ -195,10 +183,9 @@ namespace MQTTnet.Serializer var body = receivedMqttPacket.Body; ThrowIfBodyIsEmpty(body); - var fixedHeader = new ByteReader(receivedMqttPacket.FixedHeader); - var retain = fixedHeader.Read(); - var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2); - var dup = fixedHeader.Read(); + var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; + var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); + var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0; var topic = body.ReadStringWithLengthPrefix(); @@ -253,8 +240,8 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException($"Protocol name ({protocolName}) is not supported."); } - var connectFlags = new ByteReader(body.ReadByte()); - if (connectFlags.Read()) + var connectFlags = body.ReadByte(); + if ((connectFlags & 0x1) > 0) { throw new MqttProtocolViolationException("The first bit of the Connect Flags must be set to 0."); } @@ -262,14 +249,14 @@ namespace MQTTnet.Serializer var packet = new MqttConnectPacket { ProtocolVersion = protocolVersion, - CleanSession = connectFlags.Read() + CleanSession = (connectFlags & 0x2) > 0 }; - var willFlag = connectFlags.Read(); - var willQoS = connectFlags.Read(2); - var willRetain = connectFlags.Read(); - var passwordFlag = connectFlags.Read(); - var usernameFlag = connectFlags.Read(); + var willFlag = (connectFlags & 0x4) > 0; + var willQoS = (connectFlags & 0x18) >> 3; + var willRetain = (connectFlags & 0x20) > 0; + var passwordFlag = (connectFlags & 0x40) > 0; + var usernameFlag = (connectFlags & 0x80) > 0; packet.KeepAlivePeriod = body.ReadUInt16(); packet.ClientId = body.ReadStringWithLengthPrefix(); @@ -322,11 +309,11 @@ namespace MQTTnet.Serializer var packet = new MqttConnAckPacket(); - var firstByteReader = new ByteReader(body.ReadByte()); - + var acknowledgeFlags = body.ReadByte(); + if (ProtocolVersion == MqttProtocolVersion.V311) { - packet.IsSessionPresent = firstByteReader.Read(); + packet.IsSessionPresent = (acknowledgeFlags & 0x1) > 0; } packet.ConnectReturnCode = (MqttConnectReturnCode)body.ReadByte(); @@ -344,119 +331,129 @@ namespace MQTTnet.Serializer } } + // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local private static void ValidatePublishPacket(MqttPublishPacket packet) { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - if (packet.QualityOfServiceLevel == 0 && packet.Dup) { throw new MqttProtocolViolationException("Dup flag must be false for QoS 0 packets [MQTT-3.3.1-2]."); } } - private byte Serialize(MqttConnectPacket packet, Stream stream) + private byte Serialize(MqttConnectPacket packet, MqttPacketWriter packetWriter) { ValidateConnectPacket(packet); // Write variable header if (ProtocolVersion == MqttProtocolVersion.V311) { - stream.WriteWithLengthPrefix("MQTT"); - stream.WriteByte(4); // 3.1.2.2 Protocol Level 4 + packetWriter.WriteWithLengthPrefix("MQTT"); + packetWriter.Write(4); // 3.1.2.2 Protocol Level 4 } else { - stream.WriteWithLengthPrefix("MQIsdp"); - stream.WriteByte(3); // Protocol Level 3 + packetWriter.WriteWithLengthPrefix("MQIsdp"); + packetWriter.Write(3); // Protocol Level 3 } - var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags - connectFlags.Write(false); // Reserved - connectFlags.Write(packet.CleanSession); - connectFlags.Write(packet.WillMessage != null); - - if (packet.WillMessage != null) + byte connectFlags = 0x0; + if (packet.CleanSession) { - connectFlags.Write((int)packet.WillMessage.QualityOfServiceLevel, 2); - connectFlags.Write(packet.WillMessage.Retain); + connectFlags |= 0x2; } - else + + if (packet.WillMessage != null) { - connectFlags.Write(0, 2); - connectFlags.Write(false); - } + connectFlags |= 0x4; + connectFlags |= (byte)((byte)packet.WillMessage.QualityOfServiceLevel << 3); + if (packet.WillMessage.Retain) + { + connectFlags |= 0x20; + } + } + if (packet.Password != null && packet.Username == null) { throw new MqttProtocolViolationException("If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22]."); } - connectFlags.Write(packet.Password != null); - connectFlags.Write(packet.Username != null); + if (packet.Password != null) + { + connectFlags |= 0x40; + } - stream.Write(connectFlags); - stream.Write(packet.KeepAlivePeriod); - stream.WriteWithLengthPrefix(packet.ClientId); + if (packet.Username != null) + { + connectFlags |= 0x80; + } + + packetWriter.Write(connectFlags); + packetWriter.Write(packet.KeepAlivePeriod); + packetWriter.WriteWithLengthPrefix(packet.ClientId); if (packet.WillMessage != null) { - stream.WriteWithLengthPrefix(packet.WillMessage.Topic); - stream.WriteWithLengthPrefix(packet.WillMessage.Payload); + packetWriter.WriteWithLengthPrefix(packet.WillMessage.Topic); + packetWriter.WriteWithLengthPrefix(packet.WillMessage.Payload); } if (packet.Username != null) { - stream.WriteWithLengthPrefix(packet.Username); + packetWriter.WriteWithLengthPrefix(packet.Username); } if (packet.Password != null) { - stream.WriteWithLengthPrefix(packet.Password); + packetWriter.WriteWithLengthPrefix(packet.Password); } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - private byte Serialize(MqttConnAckPacket packet, Stream stream) + private byte Serialize(MqttConnAckPacket packet, MqttPacketWriter packetWriter) { if (ProtocolVersion == MqttProtocolVersion.V310) { - stream.WriteByte(0); + packetWriter.Write(0); } else if (ProtocolVersion == MqttProtocolVersion.V311) { - var connectAcknowledgeFlags = new ByteWriter(); - connectAcknowledgeFlags.Write(packet.IsSessionPresent); - - stream.Write(connectAcknowledgeFlags); + byte connectAcknowledgeFlags = 0x0; + if (packet.IsSessionPresent) + { + connectAcknowledgeFlags |= 0x1; + } + + packetWriter.Write(connectAcknowledgeFlags); } else { throw new MqttProtocolViolationException("Protocol version not supported."); } - stream.WriteByte((byte)packet.ConnectReturnCode); + packetWriter.Write((byte)packet.ConnectReturnCode); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static byte Serialize(MqttPubRelPacket packet, Stream stream) + private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubRel packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static byte Serialize(MqttPublishPacket packet, Stream stream) + private static byte Serialize(MqttPublishPacket packet, MqttPacketWriter packetWriter) { ValidatePublishPacket(packet); - stream.WriteWithLengthPrefix(packet.Topic); + packetWriter.WriteWithLengthPrefix(packet.Topic); if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { @@ -465,7 +462,7 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Publish packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); } else { @@ -477,7 +474,7 @@ namespace MQTTnet.Serializer if (packet.Payload?.Length > 0) { - stream.Write(packet.Payload, 0, packet.Payload.Length); + packetWriter.Write(packet.Payload, 0, packet.Payload.Length); } byte fixedHeader = 0; @@ -497,43 +494,43 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } - private static byte Serialize(MqttPubAckPacket packet, Stream stream) + private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static byte Serialize(MqttPubRecPacket packet, Stream stream) + private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubRec packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static byte Serialize(MqttPubCompPacket packet, Stream stream) + private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubComp packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static byte Serialize(MqttSubscribePacket packet, Stream stream) + private static byte Serialize(MqttSubscribePacket packet, MqttPacketWriter packetWriter) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); @@ -542,41 +539,41 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Subscribe packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Count > 0) { foreach (var topicFilter in packet.TopicFilters) { - stream.WriteWithLengthPrefix(topicFilter.Topic); - stream.WriteByte((byte)topicFilter.QualityOfServiceLevel); + packetWriter.WriteWithLengthPrefix(topicFilter.Topic); + packetWriter.Write((byte)topicFilter.QualityOfServiceLevel); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static byte Serialize(MqttSubAckPacket packet, Stream stream) + private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("SubAck packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); if (packet.SubscribeReturnCodes?.Any() == true) { foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes) { - stream.WriteByte((byte)packetSubscribeReturnCode); + packetWriter.Write((byte)packetSubscribeReturnCode); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static byte Serialize(MqttUnsubscribePacket packet, Stream stream) + private static byte Serialize(MqttUnsubscribePacket packet, MqttPacketWriter packetWriter) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); @@ -585,27 +582,27 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Any() == true) { foreach (var topicFilter in packet.TopicFilters) { - stream.WriteWithLengthPrefix(topicFilter); + packetWriter.WriteWithLengthPrefix(topicFilter); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte Serialize(MqttUnsubAckPacket packet, Stream stream) + private static byte Serialize(MqttUnsubAckPacket packet, MqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier."); } - stream.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } @@ -614,6 +611,7 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(type); } + // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local private static void ThrowIfBodyIsEmpty(MqttPacketBodyReader body) { if (body == null || body.Length == 0) diff --git a/Source/MQTTnet/Serializer/MqttPacketWriter.cs b/Source/MQTTnet/Serializer/MqttPacketWriter.cs index 2cc9d7b..c0c49fc 100644 --- a/Source/MQTTnet/Serializer/MqttPacketWriter.cs +++ b/Source/MQTTnet/Serializer/MqttPacketWriter.cs @@ -1,12 +1,23 @@ using System; -using System.IO; using System.Text; using MQTTnet.Protocol; namespace MQTTnet.Serializer { - public static class MqttPacketWriter + /// + /// This is a custom implementation of a memory stream which provides only MQTTnet relevant features. + /// The goal is to avoid lots of argument checks like in the original stream. The growth rule is the + /// same as for the original MemoryStream in .net. Also this implementation allows accessing the internal + /// buffer for all platforms and .net framework versions (which is not available at the regular MemoryStream). + /// + public class MqttPacketWriter { + private byte[] _buffer = new byte[128]; + + private int _position; + + public int Length { get; private set; } + public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) { var fixedHeader = (int)packetType << 4; @@ -14,33 +25,6 @@ namespace MQTTnet.Serializer return (byte)fixedHeader; } - public static void Write(this Stream stream, ushort value) - { - var buffer = BitConverter.GetBytes(value); - stream.WriteByte(buffer[1]); - stream.WriteByte(buffer[0]); - } - - public static void Write(this Stream stream, ByteWriter value) - { - if (value == null) throw new ArgumentNullException(nameof(value)); - - stream.WriteByte(value.Value); - } - - public static void WriteWithLengthPrefix(this Stream stream, string value) - { - stream.WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); - } - - public static void WriteWithLengthPrefix(this Stream stream, byte[] value) - { - var length = (ushort)value.Length; - - stream.Write(length); - stream.Write(value, 0, length); - } - public static ArraySegment EncodeRemainingLength(int length) { // write the encoded remaining length right aligned on the 4 byte buffer @@ -69,5 +53,91 @@ namespace MQTTnet.Serializer return new ArraySegment(buffer, 0, bufferOffset); } + + public void WriteWithLengthPrefix(string value) + { + WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); + } + + public void WriteWithLengthPrefix(byte[] value) + { + EnsureAdditionalCapacity(value.Length + 2); + + Write((ushort)value.Length); + Write(value, 0, value.Length); + } + + public void Write(byte @byte) + { + EnsureAdditionalCapacity(1); + + _buffer[_position] = @byte; + IncreasePostition(1); + } + + public void Write(ushort value) + { + EnsureAdditionalCapacity(2); + + _buffer[_position] = (byte)(value >> 8); + IncreasePostition(1); + _buffer[_position] = (byte)value; + IncreasePostition(1); + } + + public void Write(byte[] array, int offset, int count) + { + EnsureAdditionalCapacity(count); + + Array.Copy(array, offset, _buffer, _position, count); + IncreasePostition(count); + } + + public void Seek(int offset) + { + EnsureCapacity(offset); + _position = offset; + } + + public byte[] GetBuffer() + { + return _buffer; + } + + private void EnsureAdditionalCapacity(int additionalCapacity) + { + var freeSpace = _buffer.Length - _position; + if (freeSpace >= additionalCapacity) + { + return; + } + + EnsureCapacity(additionalCapacity - freeSpace); + } + + private void EnsureCapacity(int capacity) + { + if (_buffer.Length >= capacity) + { + return; + } + + var newBufferLength = _buffer.Length; + while (newBufferLength < capacity) + { + newBufferLength *= 2; + } + + Array.Resize(ref _buffer, newBufferLength); + } + + private void IncreasePostition(int length) + { + _position += length; + if (_position > Length) + { + Length = _position; + } + } } } diff --git a/Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs similarity index 93% rename from Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs rename to Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs index ff43e6f..7a55c0b 100644 --- a/Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs +++ b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs @@ -11,7 +11,7 @@ using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientPendingMessagesQueue : IDisposable + public class MqttClientPendingPacketsQueue : IDisposable { private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); private readonly IMqttServerOptions _options; @@ -20,13 +20,13 @@ namespace MQTTnet.Server private ConcurrentQueue _queue = new ConcurrentQueue(); - public MqttClientPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) + public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _options = options ?? throw new ArgumentNullException(nameof(options)); _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); - _logger = logger.CreateChildLogger(nameof(MqttClientPendingMessagesQueue)); + _logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue)); } public int Count => _queue.Count; @@ -115,7 +115,7 @@ namespace MQTTnet.Server return; } - await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, packet, cancellationToken).ConfigureAwait(false); + await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); } diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 28e97f0..460bea4 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -18,7 +18,7 @@ namespace MQTTnet.Server private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; - private readonly MqttClientPendingMessagesQueue _pendingMessagesQueue; + private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue; private readonly MqttClientSubscriptionsManager _subscriptionsManager; private readonly MqttClientSessionsManager _sessionsManager; @@ -49,7 +49,7 @@ namespace MQTTnet.Server _keepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, () => Stop(MqttClientDisconnectType.NotClean), _logger); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server); - _pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger); + _pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger); } public string ClientId { get; } @@ -60,7 +60,7 @@ namespace MQTTnet.Server status.IsConnected = _adapter != null; status.Endpoint = _adapter?.Endpoint; status.ProtocolVersion = _adapter?.PacketSerializer?.ProtocolVersion; - status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count; + status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count; status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived; status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; } @@ -80,7 +80,7 @@ namespace MQTTnet.Server _wasCleanDisconnect = false; _willMessage = connectPacket.WillMessage; - _pendingMessagesQueue.Start(adapter, _cancellationTokenSource.Token); + _pendingPacketsQueue.Start(adapter, _cancellationTokenSource.Token); _keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token); while (!_cancellationTokenSource.IsCancellationRequested) @@ -149,13 +149,10 @@ namespace MQTTnet.Server if (_willMessage != null && !_wasCleanDisconnect) { - _sessionsManager.StartDispatchApplicationMessage(this, _willMessage); + _sessionsManager.EnqueueApplicationMessage(this, _willMessage); } _willMessage = null; - - ////_pendingMessagesQueue.WaitForCompletion(); - ////_keepAliveMonitor.WaitForCompletion(); } finally { @@ -196,7 +193,7 @@ namespace MQTTnet.Server } } - _pendingMessagesQueue.Enqueue(publishPacket); + _pendingPacketsQueue.Enqueue(publishPacket); } public Task SubscribeAsync(IList topicFilters) @@ -226,12 +223,12 @@ namespace MQTTnet.Server public void ClearPendingApplicationMessages() { - _pendingMessagesQueue.Clear(); + _pendingPacketsQueue.Clear(); } public void Dispose() { - _pendingMessagesQueue?.Dispose(); + _pendingPacketsQueue?.Dispose(); _cancellationTokenSource?.Dispose(); } @@ -245,7 +242,7 @@ namespace MQTTnet.Server if (packet is MqttPingReqPacket) { - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket(), cancellationToken); + return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); } if (packet is MqttPubRelPacket pubRelPacket) @@ -260,7 +257,7 @@ namespace MQTTnet.Server PacketIdentifier = pubRecPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, responsePacket, cancellationToken); + return adapter.SendPacketAsync(responsePacket, cancellationToken); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -308,7 +305,7 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); - await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); + await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); if (subscribeResult.CloseConnection) { @@ -322,7 +319,7 @@ namespace MQTTnet.Server private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, unsubscribeResult, cancellationToken); + return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); } private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) @@ -333,7 +330,7 @@ namespace MQTTnet.Server { case MqttQualityOfServiceLevel.AtMostOnce: { - _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); return Task.FromResult(0); } case MqttQualityOfServiceLevel.AtLeastOnce: @@ -353,25 +350,25 @@ namespace MQTTnet.Server private Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); + return adapter.SendPacketAsync(response, cancellationToken); } private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) - _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); + _sessionsManager.EnqueueApplicationMessage(this, applicationMessage); var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); + return adapter.SendPacketAsync(response, cancellationToken); } - private Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) + private static Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }; - return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); + return adapter.SendPacketAsync(response, cancellationToken); } private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index c52c706..1cb301a 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -6,27 +6,29 @@ using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientSessionsManager : IDisposable + public class MqttClientSessionsManager { + private readonly BlockingCollection _messageQueue = new BlockingCollection(); private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); - private readonly AsyncLock _sessionPreparationLock = new AsyncLock(); + + private readonly CancellationToken _cancellationToken; private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly IMqttServerOptions _options; private readonly IMqttNetChildLogger _logger; - public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetChildLogger logger) + public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, CancellationToken cancellationToken, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager)); + _cancellationToken = cancellationToken; _options = options ?? throw new ArgumentNullException(nameof(options)); Server = server ?? throw new ArgumentNullException(nameof(server)); _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); @@ -34,7 +36,129 @@ namespace MQTTnet.Server public MqttServer Server { get; } - public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) + public void Start() + { + Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); + } + + public Task StopAsync() + { + foreach (var session in _sessions) + { + session.Value.Stop(MqttClientDisconnectType.NotClean); + } + + _sessions.Clear(); + return Task.FromResult(0); + } + + public Task StartSession(IMqttChannelAdapter clientAdapter) + { + return Task.Run(() => RunSession(clientAdapter, _cancellationToken), _cancellationToken); + } + + public Task> GetClientStatusAsync() + { + var result = new List(); + foreach (var session in _sessions) + { + var status = new MqttClientSessionStatus(this, session.Value); + session.Value.FillStatus(status); + + result.Add(status); + } + + return Task.FromResult((IList)result); + } + + public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + { + if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + + _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); + } + + public Task SubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); + } + + return session.SubscribeAsync(topicFilters); + } + + public Task UnsubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session '{clientId}' is unknown."); + } + + return session.UnsubscribeAsync(topicFilters); + } + + public void DeleteSession(string clientId) + { + _sessions.TryRemove(clientId, out _); + _logger.Verbose("Session for client '{0}' deleted.", clientId); + } + + private void ProcessQueuedApplicationMessages(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + try + { + var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken); + var sender = enqueuedApplicationMessage.Sender; + var applicationMessage = enqueuedApplicationMessage.ApplicationMessage; + + var interceptorContext = InterceptApplicationMessage(sender, applicationMessage); + if (interceptorContext != null) + { + if (interceptorContext.CloseConnection) + { + enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean); + } + + if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) + { + return; + } + + applicationMessage = interceptorContext.ApplicationMessage; + } + + Server.OnApplicationMessageReceived(sender?.ClientId, applicationMessage); + + if (applicationMessage.Retain) + { + _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult(); + } + + foreach (var clientSession in _sessions.Values) + { + clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage); + } + } + catch (TaskCanceledException) + { + } + catch (Exception exception) + { + _logger.Error(exception, "Unhandled exception while processing queued application message."); + } + } + } + + private async Task RunSession(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) { var clientId = string.Empty; var wasCleanDisconnect = false; @@ -60,7 +184,7 @@ namespace MQTTnet.Server var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout, + await clientAdapter.SendPacketAsync( new MqttConnAckPacket { ConnectReturnCode = connectReturnCode @@ -70,15 +194,15 @@ namespace MQTTnet.Server return; } - var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); + var result = PrepareClientSession(connectPacket); var clientSession = result.Session; - await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout, + await clientAdapter.SendPacketAsync( new MqttConnAckPacket { ConnectReturnCode = connectReturnCode, IsSessionPresent = result.IsExistingSession - }, + }, cancellationToken).ConfigureAwait(false); Server.OnClientConnected(clientId); @@ -113,73 +237,6 @@ namespace MQTTnet.Server } } - public Task StopAsync() - { - foreach (var session in _sessions) - { - session.Value.Stop(MqttClientDisconnectType.NotClean); - } - - _sessions.Clear(); - return Task.FromResult(0); - } - - public Task> GetClientStatusAsync() - { - var result = new List(); - foreach (var session in _sessions) - { - var status = new MqttClientSessionStatus(this, session.Value); - session.Value.FillStatus(status); - - result.Add(status); - } - - return Task.FromResult((IList)result); - } - - public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) - { - Task.Run(() => DispatchApplicationMessageAsync(senderClientSession, applicationMessage)); - } - - public Task SubscribeAsync(string clientId, IList topicFilters) - { - if (clientId == null) throw new ArgumentNullException(nameof(clientId)); - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - if (!_sessions.TryGetValue(clientId, out var session)) - { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } - - return session.SubscribeAsync(topicFilters); - } - - public Task UnsubscribeAsync(string clientId, IList topicFilters) - { - if (clientId == null) throw new ArgumentNullException(nameof(clientId)); - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - if (!_sessions.TryGetValue(clientId, out var session)) - { - throw new InvalidOperationException($"Client session '{clientId}' is unknown."); - } - - return session.UnsubscribeAsync(topicFilters); - } - - public void DeleteSession(string clientId) - { - _sessions.TryRemove(clientId, out _); - _logger.Verbose("Session for client '{0}' deleted.", clientId); - } - - public void Dispose() - { - _sessionPreparationLock?.Dispose(); - } - private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) { if (_options.ConnectionValidator == null) @@ -197,9 +254,9 @@ namespace MQTTnet.Server return context.ReturnCode; } - private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) + private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket) { - using (await _sessionPreparationLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + lock (_sessions) { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) @@ -231,60 +288,19 @@ namespace MQTTnet.Server _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } - return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; + return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; } } - private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) { - try - { - var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage); - if (interceptorContext != null) - { - if (interceptorContext.CloseConnection) - { - senderClientSession.Stop(MqttClientDisconnectType.NotClean); - } - - if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) - { - return; - } - - applicationMessage = interceptorContext.ApplicationMessage; - } - - Server.OnApplicationMessageReceived(senderClientSession?.ClientId, applicationMessage); - - if (applicationMessage.Retain) - { - await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false); - } - - foreach (var clientSession in _sessions.Values) - { - clientSession.EnqueueApplicationMessage(senderClientSession, applicationMessage); - } - } - catch (Exception exception) - { - _logger.Error(exception, "Error while processing application message"); - } - } - - private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) - { - var interceptorContext = new MqttApplicationMessageInterceptorContext( - senderClientSession?.ClientId, - applicationMessage); - var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) { - return interceptorContext; + return null; } + var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage); interceptor(interceptorContext); return interceptorContext; } diff --git a/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs new file mode 100644 index 0000000..20ff2fe --- /dev/null +++ b/Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs @@ -0,0 +1,15 @@ +namespace MQTTnet.Server +{ + public class MqttEnqueuedApplicationMessage + { + public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) + { + Sender = sender; + ApplicationMessage = applicationMessage; + } + + public MqttClientSession Sender { get; } + + public MqttApplicationMessage ApplicationMessage { get; } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index a38dfd2..38e631e 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -65,7 +65,7 @@ namespace MQTTnet.Server if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started."); - _clientSessionsManager.StartDispatchApplicationMessage(null, applicationMessage); + _clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage); return Task.FromResult(0); } @@ -81,7 +81,8 @@ namespace MQTTnet.Server _retainedMessagesManager = new MqttRetainedMessagesManager(Options, _logger); await _retainedMessagesManager.LoadMessagesAsync().ConfigureAwait(false); - _clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _logger); + _clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _cancellationTokenSource.Token, _logger); + _clientSessionsManager.Start(); foreach (var adapter in _adapters) { @@ -118,8 +119,6 @@ namespace MQTTnet.Server } finally { - _clientSessionsManager?.Dispose(); - _cancellationTokenSource = null; _retainedMessagesManager = null; _clientSessionsManager = null; @@ -155,9 +154,7 @@ namespace MQTTnet.Server private void OnClientAccepted(object sender, MqttServerAdapterClientAcceptedEventArgs eventArgs) { - eventArgs.SessionTask = Task.Run( - () => _clientSessionsManager.RunSessionAsync(eventArgs.Client, _cancellationTokenSource.Token), - _cancellationTokenSource.Token); + eventArgs.SessionTask = _clientSessionsManager.StartSession(eventArgs.Client); } } } diff --git a/Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs b/Source/MQTTnet/Server/PrepareClientSessionResult.cs similarity index 76% rename from Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs rename to Source/MQTTnet/Server/PrepareClientSessionResult.cs index 975d237..9a655be 100644 --- a/Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs +++ b/Source/MQTTnet/Server/PrepareClientSessionResult.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Server { - public class GetOrCreateClientSessionResult + public class PrepareClientSessionResult { public bool IsExistingSession { get; set; } diff --git a/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs b/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs index 20b581c..7f068d6 100644 --- a/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs @@ -65,7 +65,7 @@ namespace MQTTnet.Benchmarks for (var i = 0; i < 10000; i++) { - _channelAdapter.SendPacketAsync(TimeSpan.FromSeconds(15), _packet, CancellationToken.None).GetAwaiter().GetResult(); + _channelAdapter.SendPacketAsync(_packet, CancellationToken.None).GetAwaiter().GetResult(); } _stream.Position = 0; diff --git a/Tests/MQTTnet.Core.Tests/ByteReaderTests.cs b/Tests/MQTTnet.Core.Tests/ByteReaderTests.cs deleted file mode 100644 index e2173cb..0000000 --- a/Tests/MQTTnet.Core.Tests/ByteReaderTests.cs +++ /dev/null @@ -1,30 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Serializer; - -namespace MQTTnet.Core.Tests -{ - [TestClass] - public class ByteReaderTests - { - [TestMethod] - public void ByteReader_ReadToEnd() - { - var reader = new ByteReader(85); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - Assert.IsTrue(reader.Read()); - Assert.IsFalse(reader.Read()); - } - - [TestMethod] - public void ByteReader_ReadPartial() - { - var reader = new ByteReader(15); - Assert.AreEqual(3, reader.Read(2)); - } - } -} diff --git a/Tests/MQTTnet.Core.Tests/ByteWriterTests.cs b/Tests/MQTTnet.Core.Tests/ByteWriterTests.cs deleted file mode 100644 index 881df5c..0000000 --- a/Tests/MQTTnet.Core.Tests/ByteWriterTests.cs +++ /dev/null @@ -1,51 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Serializer; - -namespace MQTTnet.Core.Tests -{ - [TestClass] - public class ByteWriterTests - { - [TestMethod] - public void ByteWriter_WriteMultipleAll() - { - var b = new ByteWriter(); - Assert.AreEqual(0, b.Value); - b.Write(3, 2); - Assert.AreEqual(3, b.Value); - } - - [TestMethod] - public void ByteWriter_WriteMultiplePartial() - { - var b = new ByteWriter(); - Assert.AreEqual(0, b.Value); - b.Write(255, 2); - Assert.AreEqual(3, b.Value); - } - - [TestMethod] - public void ByteWriter_WriteTo0xFF() - { - var b = new ByteWriter(); - - Assert.AreEqual(0, b.Value); - b.Write(true); - Assert.AreEqual(1, b.Value); - b.Write(true); - Assert.AreEqual(3, b.Value); - b.Write(true); - Assert.AreEqual(7, b.Value); - b.Write(true); - Assert.AreEqual(15, b.Value); - b.Write(true); - Assert.AreEqual(31, b.Value); - b.Write(true); - Assert.AreEqual(63, b.Value); - b.Write(true); - Assert.AreEqual(127, b.Value); - b.Write(true); - Assert.AreEqual(255, b.Value); - } - } -} diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index 696cfa3..fdb5888 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests [TestMethod] public async Task TimeoutAfter() { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] public async Task TimeoutAfterWithResult() { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [TestMethod] public async Task TimeoutAfterCompleteInTime() { - var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); + var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); Assert.AreEqual(5, result); } @@ -36,7 +36,7 @@ namespace MQTTnet.Core.Tests { try { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -55,7 +55,7 @@ namespace MQTTnet.Core.Tests { try { - await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -76,7 +76,7 @@ namespace MQTTnet.Core.Tests var tasks = Enumerable.Range(0, 100000) .Select(i => { - return MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); + return MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); }); await Task.WhenAll(tasks); diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index e3e1ccd..f983622 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Core.Tests public class MqttPacketReaderTests { [TestMethod] - [ExpectedException(typeof(MqttCommunicationException))] + [ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] public void MqttPacketReader_EmptyStream() { MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index ffe2a77..9375f07 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -36,7 +35,7 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken) + public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); diff --git a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs index e9e4bc4..752feb5 100644 --- a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs @@ -12,16 +12,46 @@ namespace MQTTnet.TestApp.NetCore { public static class PerformanceTest { - public static async Task RunAsync() + public static void Run() { - Console.WriteLine("Press 'c' for concurrent sends. Otherwise in one batch."); - var concurrent = Console.ReadKey(true).KeyChar == 'c'; + try + { + var mqttServer = new MqttFactory().CreateMqttServer(); + mqttServer.StartAsync(new MqttServerOptions()).GetAwaiter().GetResult(); - var server = Task.Run(RunServerAsync); - await Task.Delay(1000); - var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10), concurrent)); + var options = new MqttClientOptions + { + ChannelOptions = new MqttClientTcpOptions + { + Server = "127.0.0.1" + }, + CleanSession = true + }; + + var client = new MqttFactory().CreateMqttClient(); + client.ConnectAsync(options).GetAwaiter().GetResult(); + + var message = CreateMessage(); + var stopwatch = new Stopwatch(); - await Task.WhenAll(server, client).ConfigureAwait(false); + for (var i = 0; i < 10; i++) + { + stopwatch.Restart(); + + var sentMessagesCount = 0; + while (stopwatch.ElapsedMilliseconds < 1000) + { + client.PublishAsync(message).GetAwaiter().GetResult(); + sentMessagesCount++; + } + + Console.WriteLine($"Sending {sentMessagesCount} messages per second. #" + (i + 1)); + } + } + catch (Exception exception) + { + Console.WriteLine(exception); + } } private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval, bool concurrent) @@ -53,29 +83,8 @@ namespace MQTTnet.TestApp.NetCore } var message = CreateMessage(); - var messages = new[] { message }; - var stopwatch = Stopwatch.StartNew(); - var sentMessagesCount = 0; - while (stopwatch.ElapsedMilliseconds < 1000) - { - client.PublishAsync(messages).GetAwaiter().GetResult(); - sentMessagesCount++; - } - - Console.WriteLine($"Sending {sentMessagesCount} messages per second. #1"); - - sentMessagesCount = 0; - stopwatch.Restart(); - while (stopwatch.ElapsedMilliseconds < 1000) - { - await client.PublishAsync(messages).ConfigureAwait(false); - sentMessagesCount++; - } - - Console.WriteLine($"Sending {sentMessagesCount} messages per second. #2"); - var testMessageCount = 10000; for (var i = 0; i < testMessageCount; i++) { @@ -142,38 +151,5 @@ namespace MQTTnet.TestApp.NetCore Interlocked.Increment(ref count); return Task.Run(() => client.PublishAsync(applicationMessage)); } - - private static async Task RunServerAsync() - { - try - { - var mqttServer = new MqttFactory().CreateMqttServer(); - - ////var msgs = 0; - ////var stopwatch = Stopwatch.StartNew(); - ////mqttServer.ApplicationMessageReceived += (sender, args) => - ////{ - //// msgs++; - //// if (stopwatch.ElapsedMilliseconds > 1000) - //// { - //// Console.WriteLine($"received {msgs}"); - //// msgs = 0; - //// stopwatch.Restart(); - //// } - ////}; - await mqttServer.StartAsync(new MqttServerOptions()); - - Console.WriteLine("Press any key to exit."); - Console.ReadLine(); - - await mqttServer.StopAsync().ConfigureAwait(false); - } - catch (Exception e) - { - Console.WriteLine(e); - } - - Console.ReadLine(); - } } } diff --git a/Tests/MQTTnet.TestApp.NetCore/Program.cs b/Tests/MQTTnet.TestApp.NetCore/Program.cs index f8a5d27..25302c7 100644 --- a/Tests/MQTTnet.TestApp.NetCore/Program.cs +++ b/Tests/MQTTnet.TestApp.NetCore/Program.cs @@ -34,7 +34,8 @@ namespace MQTTnet.TestApp.NetCore } else if (pressedKey.KeyChar == '3') { - Task.Run(PerformanceTest.RunAsync); + PerformanceTest.Run(); + return; } else if (pressedKey.KeyChar == '4') {