From 76105de4c7b28736e8d51739936f6c487a2f86f1 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Fri, 15 Sep 2017 15:26:26 +0200 Subject: [PATCH] Refactor latest changes --- .../Implementations/MqttTcpChannel.cs | 67 +++++------ .../Implementations/MqttWebSocketChannel.cs | 1 - .../Implementations/WebSocketStream.cs | 9 +- .../Implementations/MqttTcpChannel.cs | 53 ++++----- .../Implementations/MqttTcpChannel.cs | 97 ++++++---------- .../Implementations/MqttWebSocketChannel.cs | 1 - .../MQTTnet.UniversalWindows.csproj | 5 + .../MqttChannelCommunicationAdapter.cs | 106 +++++++++++------- MQTTnet.Core/Adapter/ReceivedMqttPacket.cs | 19 ++++ MQTTnet.Core/Client/MqttClient.cs | 5 + MQTTnet.Core/Client/MqttPacketDispatcher.cs | 12 +- MQTTnet.Core/Internal/TaskExtensions.cs | 24 ++-- .../Serializer/IMqttPacketSerializer.cs | 4 +- MQTTnet.Core/Serializer/MqttPacketReader.cs | 45 ++++---- .../Serializer/MqttPacketSerializer.cs | 10 +- .../MqttPacketSerializerTests.cs | 17 +-- .../PerformanceTest.cs | 53 +++++---- .../MQTTnet.TestApp.UniversalWindows.csproj | 2 +- 18 files changed, 268 insertions(+), 262 deletions(-) create mode 100644 MQTTnet.Core/Adapter/ReceivedMqttPacket.cs diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs index 33ba567..f7d8563 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs @@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; -using MQTTnet.Core.Exceptions; using System.IO; namespace MQTTnet.Implementations @@ -16,10 +15,6 @@ namespace MQTTnet.Implementations private Socket _socket; private SslStream _sslStream; - public Stream RawStream { get; private set; } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } - /// /// called on client sockets are created in connect /// @@ -36,61 +31,61 @@ namespace MQTTnet.Implementations { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); _sslStream = sslStream; - CreateCommStreams(socket, sslStream); + CreateStreams(socket, sslStream); } + public Stream RawStream { get; private set; } + public Stream SendStream { get; private set; } + public Stream ReceiveStream { get; private set; } + public async Task ConnectAsync(MqttClientOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); - try - { - if (_socket == null) - { - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); - } - - await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false); - if (options.TlsOptions.UseTls) - { - _sslStream = new SslStream(new NetworkStream(_socket, true)); + if (_socket == null) + { + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + } - await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); - } + await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false); - CreateCommStreams(_socket, _sslStream); - } - catch (SocketException exception) + if (options.TlsOptions.UseTls) { - throw new MqttCommunicationException(exception); + _sslStream = new SslStream(new NetworkStream(_socket, true)); + + await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); } + + CreateStreams(_socket, _sslStream); } public Task DisconnectAsync() { - try - { - Dispose(); - return Task.FromResult(0); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } + Dispose(); + return Task.FromResult(0); } public void Dispose() { - _socket?.Dispose(); - _sslStream?.Dispose(); + RawStream?.Dispose(); + RawStream = null; + + ReceiveStream?.Dispose(); + ReceiveStream = null; + + SendStream?.Dispose(); + SendStream = null; + _socket?.Dispose(); _socket = null; + + _sslStream?.Dispose(); _sslStream = null; } - private void CreateCommStreams(Socket socket, SslStream sslStream) + private void CreateStreams(Socket socket, Stream sslStream) { - RawStream = (Stream)sslStream ?? new NetworkStream(socket); + RawStream = sslStream ?? new NetworkStream(socket); //cannot use this as default buffering prevents from receiving the first connect message //need two streams otherwise read and write have to be synchronized diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs index bc224cb..b639fe0 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs @@ -14,7 +14,6 @@ namespace MQTTnet.Implementations private ClientWebSocket _webSocket = new ClientWebSocket(); public Stream RawStream { get; private set; } - public Stream SendStream => RawStream; public Stream ReceiveStream => RawStream; diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs index d912148..4f05936 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs @@ -56,15 +56,12 @@ namespace MQTTnet.Implementations public override bool CanSeek => false; public override bool CanWrite => true; - public override long Length - { - get { throw new NotSupportedException(); } - } + public override long Length => throw new NotSupportedException(); public override long Position { - get { throw new NotSupportedException(); } - set { throw new NotSupportedException(); } + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); } public override long Seek(long offset, SeekOrigin origin) diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index 0ef5dad..962f763 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; -using MQTTnet.Core.Exceptions; using System.IO; namespace MQTTnet.Implementations @@ -16,10 +15,6 @@ namespace MQTTnet.Implementations private Socket _socket; private SslStream _sslStream; - public Stream ReceiveStream { get; private set; } - public Stream RawStream => ReceiveStream; - public Stream SendStream => ReceiveStream; - /// /// called on client sockets are created in connect /// @@ -38,55 +33,45 @@ namespace MQTTnet.Implementations ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket); } + public Stream ReceiveStream { get; private set; } + public Stream RawStream => ReceiveStream; + public Stream SendStream => ReceiveStream; + public async Task ConnectAsync(MqttClientOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); - try + if (_socket == null) { - if (_socket == null) - { - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); - } + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + } - await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); + await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); - if (options.TlsOptions.UseTls) - { - _sslStream = new SslStream(new NetworkStream(_socket, true)); - ReceiveStream = _sslStream; - await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); - } - else - { - ReceiveStream = new NetworkStream(_socket); - } + if (options.TlsOptions.UseTls) + { + _sslStream = new SslStream(new NetworkStream(_socket, true)); + ReceiveStream = _sslStream; + await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); } - catch (SocketException exception) + else { - throw new MqttCommunicationException(exception); + ReceiveStream = new NetworkStream(_socket); } } public Task DisconnectAsync() { - try - { - Dispose(); - return Task.FromResult(0); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } + Dispose(); + return Task.FromResult(0); } public void Dispose() { _socket?.Dispose(); - _sslStream?.Dispose(); - _socket = null; + + _sslStream?.Dispose(); _sslStream = null; } diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs index eeae7a6..c1fb9a4 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs @@ -1,15 +1,13 @@ using System; +using System.IO; using System.Linq; -using System.Net.Sockets; using System.Runtime.InteropServices.WindowsRuntime; using System.Threading.Tasks; using Windows.Networking; using Windows.Networking.Sockets; using Windows.Security.Cryptography.Certificates; -using Windows.Storage.Streams; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; -using MQTTnet.Core.Exceptions; namespace MQTTnet.Implementations { @@ -26,89 +24,58 @@ namespace MQTTnet.Implementations _socket = socket ?? throw new ArgumentNullException(nameof(socket)); } + public Stream SendStream { get; private set; } + public Stream ReceiveStream { get; private set; } + public Stream RawStream { get; private set; } + public async Task ConnectAsync(MqttClientOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); - try - { - if (_socket == null) - { - _socket = new StreamSocket(); - } - - if (!options.TlsOptions.UseTls) - { - await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); - } - else - { - _socket.Control.ClientCertificate = LoadCertificate(options); - - if (!options.TlsOptions.CheckCertificateRevocation) - { - _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); - _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); - } - await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12); - } - } - catch (SocketException exception) + if (_socket == null) { - throw new MqttCommunicationException(exception); + _socket = new StreamSocket(); } - } - public Task DisconnectAsync() - { - try + if (!options.TlsOptions.UseTls) { - Dispose(); - return Task.FromResult(0); + await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); } - catch (SocketException exception) + else { - throw new MqttCommunicationException(exception); - } - } + _socket.Control.ClientCertificate = LoadCertificate(options); - public async Task WriteAsync(byte[] buffer) - { - if (buffer == null) throw new ArgumentNullException(nameof(buffer)); + if (!options.TlsOptions.CheckCertificateRevocation) + { + _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); + _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); + } - try - { - await _socket.OutputStream.WriteAsync(buffer.AsBuffer()); - await _socket.OutputStream.FlushAsync(); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); + await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12); } - } - public int Peek() - { - return 0; + ReceiveStream = _socket.InputStream.AsStreamForRead(); + SendStream = _socket.OutputStream.AsStreamForWrite(); + RawStream = ReceiveStream; } - public async Task> ReadAsync(int length, byte[] buffer) + public Task DisconnectAsync() { - if (buffer == null) throw new ArgumentNullException(nameof(buffer)); - - try - { - var result = await _socket.InputStream.ReadAsync(buffer.AsBuffer(), (uint)buffer.Length, InputStreamOptions.None); - return new ArraySegment(buffer, 0, (int)result.Length); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } + Dispose(); + return Task.FromResult(0); } public void Dispose() { + RawStream?.Dispose(); + RawStream = null; + + SendStream?.Dispose(); + SendStream = null; + + ReceiveStream?.Dispose(); + ReceiveStream = null; + _socket?.Dispose(); _socket = null; } diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs index bc224cb..b639fe0 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs @@ -14,7 +14,6 @@ namespace MQTTnet.Implementations private ClientWebSocket _webSocket = new ClientWebSocket(); public Stream RawStream { get; private set; } - public Stream SendStream => RawStream; public Stream ReceiveStream => RawStream; diff --git a/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj b/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj index 8767e24..f4205c3 100644 --- a/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj +++ b/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj @@ -130,6 +130,11 @@ 5.3.3 + + + ..\..\..\..\Program Files\dotnet\sdk\NuGetFallbackFolder\microsoft.netcore.app\2.0.0\ref\netcoreapp2.0\System.Net.Security.dll + + 14.0 diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index 241bd54..a0239d5 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -15,7 +15,6 @@ namespace MQTTnet.Core.Adapter public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter { private readonly IMqttCommunicationChannel _channel; - private readonly byte[] _readBuffer = new byte[BufferConstants.Size]; private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write @@ -29,76 +28,105 @@ namespace MQTTnet.Core.Adapter public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) { - return _channel.ConnectAsync(options).TimeoutAfter(timeout); + try + { + return _channel.ConnectAsync(options).TimeoutAfter(timeout); + } + catch (Exception exception) + { + throw new MqttCommunicationException(exception); + } } public Task DisconnectAsync() { - return _channel.DisconnectAsync(); + try + { + return _channel.DisconnectAsync(); + } + catch (Exception exception) + { + throw new MqttCommunicationException(exception); + } } public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets) { - lock (_channel) + try { - foreach (var packet in packets) + lock (_channel) { - MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); + foreach (var packet in packets) + { + MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); - var writeBuffer = PacketSerializer.Serialize(packet); - - _sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); + var writeBuffer = PacketSerializer.Serialize(packet); + _sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); + } } - } - await _sendTask; // configure await false geneates stackoverflow - await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); + await _sendTask; // configure await false geneates stackoverflow + await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); + } + catch (Exception exception) + { + throw new MqttCommunicationException(exception); + } } public async Task ReceivePacketAsync(TimeSpan timeout) { - Tuple tuple; - if (timeout > TimeSpan.Zero) - { - tuple = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); - } - else + try { - tuple = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); - } + ReceivedMqttPacket receivedMqttPacket; + if (timeout > TimeSpan.Zero) + { + receivedMqttPacket = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); + } + else + { + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); + } - var packet = PacketSerializer.Deserialize(tuple.Item1, tuple.Item2); + var packet = PacketSerializer.Deserialize(receivedMqttPacket); + if (packet == null) + { + throw new MqttProtocolViolationException("Received malformed packet."); + } - if (packet == null) + MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); + return packet; + } + catch (Exception exception) { - throw new MqttProtocolViolationException("Received malformed packet."); + throw new MqttCommunicationException(exception); } - - MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); - return packet; } - private async Task> ReceiveAsync(Stream stream) + private async Task ReceiveAsync(Stream stream) { var header = MqttPacketReader.ReadHeaderFromSource(stream); - MemoryStream body; - if (header.BodyLength > 0) + if (header.BodyLength == 0) { - var totalRead = 0; - do - { - var read = await stream.ReadAsync(_readBuffer, totalRead, header.BodyLength - totalRead).ConfigureAwait(false); - totalRead += read; - } while (totalRead < header.BodyLength); - body = new MemoryStream(_readBuffer, 0, header.BodyLength); + return new ReceivedMqttPacket(header, new MemoryStream(0)); } - else + + var body = new byte[header.BodyLength]; + + var offset = 0; + do + { + var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset).ConfigureAwait(false); + offset += readBytesCount; + } while (offset < header.BodyLength); + + if (offset > header.BodyLength) { - body = new MemoryStream(); + throw new MqttCommunicationException($"Read more body bytes than required ({offset}/{header.BodyLength})."); } - return Tuple.Create(header, body); + return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length)); } } } \ No newline at end of file diff --git a/MQTTnet.Core/Adapter/ReceivedMqttPacket.cs b/MQTTnet.Core/Adapter/ReceivedMqttPacket.cs new file mode 100644 index 0000000..f8adc2e --- /dev/null +++ b/MQTTnet.Core/Adapter/ReceivedMqttPacket.cs @@ -0,0 +1,19 @@ +using System; +using System.IO; +using MQTTnet.Core.Packets; + +namespace MQTTnet.Core.Adapter +{ + public class ReceivedMqttPacket + { + public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body) + { + Header = header ?? throw new ArgumentNullException(nameof(header)); + Body = body ?? throw new ArgumentNullException(nameof(body)); + } + + public MqttPacketHeader Header { get; } + + public MemoryStream Body { get; } + } +} diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index aabc5a6..3239f54 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -99,6 +99,11 @@ namespace MQTTnet.Core.Client public async Task DisconnectAsync() { + if (!IsConnected) + { + return; + } + try { await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); diff --git a/MQTTnet.Core/Client/MqttPacketDispatcher.cs b/MQTTnet.Core/Client/MqttPacketDispatcher.cs index 21c177c..e494509 100644 --- a/MQTTnet.Core/Client/MqttPacketDispatcher.cs +++ b/MQTTnet.Core/Client/MqttPacketDispatcher.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Core.Client public class MqttPacketDispatcher { private readonly ConcurrentDictionary> _packetByResponseType = new ConcurrentDictionary>(); - private readonly ConcurrentDictionary>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary>>(); + private readonly ConcurrentDictionary>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary>>(); public async Task WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) { @@ -24,7 +24,7 @@ namespace MQTTnet.Core.Client } catch (MqttCommunicationTimedOutException) { - MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); + MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet of type '{0}'.", responseType.Name); throw; } finally @@ -42,16 +42,20 @@ namespace MQTTnet.Core.Client { if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) { - if (byid.TryRemove( withIdentifier.PacketIdentifier, out var tcs)) + if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) { - tcs.TrySetResult( packet ); + tcs.TrySetResult(packet); + return; } } } else if (_packetByResponseType.TryRemove(type, out var tcs)) { tcs.TrySetResult(packet); + return; } + + throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); } public void Reset() diff --git a/MQTTnet.Core/Internal/TaskExtensions.cs b/MQTTnet.Core/Internal/TaskExtensions.cs index 485bd57..67890f4 100644 --- a/MQTTnet.Core/Internal/TaskExtensions.cs +++ b/MQTTnet.Core/Internal/TaskExtensions.cs @@ -25,20 +25,22 @@ namespace MQTTnet.Core.Internal try { - cancellationTokenSource.CancelAfter(timeout); + #pragma warning disable 4014 task.ContinueWith(t => - { - if (t.IsFaulted) - { - tcs.TrySetException(t.Exception); - } + #pragma warning restore 4014 + { + if (t.IsFaulted) + { + tcs.TrySetException(t.Exception); + } - if (t.IsCompleted) - { - tcs.TrySetResult(t.Result); - } - }, cancellationTokenSource.Token); + if (t.IsCompleted) + { + tcs.TrySetResult(t.Result); + } + }, cancellationTokenSource.Token); + cancellationTokenSource.CancelAfter(timeout); return await tcs.Task; } catch (TaskCanceledException) diff --git a/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs b/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs index df5b045..5834c53 100644 --- a/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs @@ -1,4 +1,4 @@ -using System.IO; +using MQTTnet.Core.Adapter; using MQTTnet.Core.Packets; namespace MQTTnet.Core.Serializer @@ -9,6 +9,6 @@ namespace MQTTnet.Core.Serializer byte[] Serialize(MqttBasePacket mqttPacket); - MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream stream); + MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket); } } \ No newline at end of file diff --git a/MQTTnet.Core/Serializer/MqttPacketReader.cs b/MQTTnet.Core/Serializer/MqttPacketReader.cs index def590b..f7fbdb2 100644 --- a/MQTTnet.Core/Serializer/MqttPacketReader.cs +++ b/MQTTnet.Core/Serializer/MqttPacketReader.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Text; +using MQTTnet.Core.Adapter; using MQTTnet.Core.Exceptions; using MQTTnet.Core.Protocol; using MQTTnet.Core.Packets; @@ -9,15 +10,29 @@ namespace MQTTnet.Core.Serializer { public sealed class MqttPacketReader : BinaryReader { - private readonly MqttPacketHeader _header; - - public MqttPacketReader(Stream stream, MqttPacketHeader header) - : base(stream, Encoding.UTF8, true) + private readonly ReceivedMqttPacket _receivedMqttPacket; + + public MqttPacketReader(ReceivedMqttPacket receivedMqttPacket) + : base(receivedMqttPacket.Body, Encoding.UTF8, true) { - _header = header; + _receivedMqttPacket = receivedMqttPacket; } - public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; + public bool EndOfRemainingData => BaseStream.Position == _receivedMqttPacket.Header.BodyLength; + + public static MqttPacketHeader ReadHeaderFromSource(Stream stream) + { + var fixedHeader = (byte)stream.ReadByte(); + var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); + var bodyLength = ReadBodyLengthFromSource(stream); + + return new MqttPacketHeader + { + FixedHeader = fixedHeader, + ControlPacketType = controlPacketType, + BodyLength = bodyLength + }; + } public override ushort ReadUInt16() { @@ -44,21 +59,7 @@ namespace MQTTnet.Core.Serializer public byte[] ReadRemainingData() { - return ReadBytes(_header.BodyLength - (int)BaseStream.Position); - } - - public static MqttPacketHeader ReadHeaderFromSource(Stream stream) - { - var fixedHeader = (byte)stream.ReadByte(); - var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); - var bodyLength = ReadBodyLengthFromSource(stream); - - return new MqttPacketHeader - { - FixedHeader = fixedHeader, - ControlPacketType = controlPacketType, - BodyLength = bodyLength - }; + return ReadBytes(_receivedMqttPacket.Header.BodyLength - (int)BaseStream.Position); } private static int ReadBodyLengthFromSource(Stream stream) @@ -74,7 +75,7 @@ namespace MQTTnet.Core.Serializer multiplier *= 128; if (multiplier > 128 * 128 * 128) { - throw new MqttProtocolViolationException("Remaining length is ivalid."); + throw new MqttProtocolViolationException("Remaining length is invalid."); } } while ((encodedByte & 128) != 0); return value; diff --git a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs index 5dea752..fcec63c 100644 --- a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using MQTTnet.Core.Adapter; using MQTTnet.Core.Exceptions; using MQTTnet.Core.Packets; using MQTTnet.Core.Protocol; @@ -110,14 +111,13 @@ namespace MQTTnet.Core.Serializer throw new MqttProtocolViolationException("Packet type invalid."); } - public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) + public MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket) { - if (header == null) throw new ArgumentNullException(nameof(header)); - if (body == null) throw new ArgumentNullException(nameof(body)); + if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket)); - using (var reader = new MqttPacketReader(body, header)) + using (var reader = new MqttPacketReader(receivedMqttPacket)) { - return Deserialize(header, reader); + return Deserialize(receivedMqttPacket.Header, reader); } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index e0ae8fa..c759f23 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -3,6 +3,7 @@ using System.IO; using System.Text; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Adapter; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; using MQTTnet.Core.Packets; @@ -436,20 +437,20 @@ namespace MQTTnet.Core.Tests private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value) { var serializer = new MqttPacketSerializer(); - + var buffer1 = serializer.Serialize(packet); - using (var headerStream = new MemoryStream( buffer1 )) + using (var headerStream = new MemoryStream(buffer1)) { - var header = MqttPacketReader.ReadHeaderFromSource( headerStream ); + var header = MqttPacketReader.ReadHeaderFromSource(headerStream); - using (var bodyStream = new MemoryStream( buffer1, (int)headerStream.Position, header.BodyLength )) + using (var bodyStream = new MemoryStream(buffer1, (int)headerStream.Position, header.BodyLength)) { - var deserializedPacket = serializer.Deserialize(header, bodyStream); - var buffer2 = serializer.Serialize( deserializedPacket ); + var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header, bodyStream)); + var buffer2 = serializer.Serialize(deserializedPacket); - Assert.AreEqual( expectedBase64Value, Convert.ToBase64String( buffer2 ) ); - } + Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(buffer2)); + } } } } diff --git a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs index de5cd27..2aee406 100644 --- a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs @@ -17,18 +17,18 @@ namespace MQTTnet.TestApp.NetFramework { public static async Task RunAsync() { - var server = Task.Run(() => RunServerAsync()); - var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10))); + var server = Task.Factory.StartNew(RunServerAsync, TaskCreationOptions.LongRunning); + var client = Task.Factory.StartNew(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10)), TaskCreationOptions.LongRunning); await Task.WhenAll(server, client).ConfigureAwait(false); } private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval) { - return Task.WhenAll(Enumerable.Range(0, 3).Select((i) => Task.Run(() => RunClientAsync(msgChunkSize, interval)))); + return Task.WhenAll(Enumerable.Range(0, 3).Select(i => Task.Run(() => RunClientAsync(msgChunkSize, interval)))); } - private static async Task RunClientAsync( int msgChunkSize, TimeSpan interval ) + private static async Task RunClientAsync(int msgChunkSize, TimeSpan interval) { try { @@ -83,7 +83,7 @@ namespace MQTTnet.TestApp.NetFramework Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###"); - var testMessageCount = 1000; + var testMessageCount = 10000; var message = CreateMessage(); var stopwatch = Stopwatch.StartNew(); for (var i = 0; i < testMessageCount; i++) @@ -92,8 +92,8 @@ namespace MQTTnet.TestApp.NetFramework } stopwatch.Stop(); - Console.WriteLine($"Sent 1000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); - + Console.WriteLine($"Sent 10.000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); + stopwatch.Restart(); var sentMessagesCount = 0; while (stopwatch.ElapsedMilliseconds < 1000) @@ -109,32 +109,32 @@ namespace MQTTnet.TestApp.NetFramework while (true) { - var msgs = Enumerable.Range( 0, msgChunkSize ) - .Select( i => CreateMessage() ) + var msgs = Enumerable.Range(0, msgChunkSize) + .Select(i => CreateMessage()) .ToList(); if (false) { //send concurrent (test for raceconditions) var sendTasks = msgs - .Select( msg => PublishSingleMessage( client, msg, ref msgCount ) ) + .Select(msg => PublishSingleMessage(client, msg, ref msgCount)) .ToList(); - await Task.WhenAll( sendTasks ); + await Task.WhenAll(sendTasks); } else { - await client.PublishAsync( msgs ); + await client.PublishAsync(msgs); msgCount += msgs.Count; //send multiple } - + var now = DateTime.Now; if (last < now - TimeSpan.FromSeconds(1)) { - Console.WriteLine( $"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}" ); + Console.WriteLine($"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}"); msgCount = 0; last = now; } @@ -152,19 +152,19 @@ namespace MQTTnet.TestApp.NetFramework { return new MqttApplicationMessage( "A/B/C", - Encoding.UTF8.GetBytes( "Hello World" ), + Encoding.UTF8.GetBytes("Hello World"), MqttQualityOfServiceLevel.AtMostOnce, false ); } - private static Task PublishSingleMessage( IMqttClient client, MqttApplicationMessage applicationMessage, ref int count ) + private static Task PublishSingleMessage(IMqttClient client, MqttApplicationMessage applicationMessage, ref int count) { - Interlocked.Increment( ref count ); - return Task.Run( () => - { - return client.PublishAsync( applicationMessage ); - } ); + Interlocked.Increment(ref count); + return Task.Run(() => + { + return client.PublishAsync(applicationMessage); + }); } private static void RunServerAsync() @@ -187,19 +187,18 @@ namespace MQTTnet.TestApp.NetFramework }, DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) }; - + var mqttServer = new MqttServerFactory().CreateMqttServer(options); - var last = DateTime.Now; var msgs = 0; - mqttServer.ApplicationMessageReceived += (sender, args) => + var stopwatch = Stopwatch.StartNew(); + mqttServer.ApplicationMessageReceived += (sender, args) => { msgs++; - var now = DateTime.Now; - if (last < now - TimeSpan.FromSeconds(1)) + if (stopwatch.ElapsedMilliseconds > 1000) { Console.WriteLine($"received {msgs}"); msgs = 0; - last = now; + stopwatch.Restart(); } }; mqttServer.Start(); diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj b/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj index 502c3be..fde3fa4 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj @@ -42,7 +42,7 @@ false prompt true - true + false true