diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index a4300ec..76ac050 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -17,8 +17,7 @@ namespace MQTTnet.Adapter public class MqttChannelAdapter : IMqttChannelAdapter { private const uint ErrorOperationAborted = 0x800703E3; - - private static readonly byte[] EmptyBody = new byte[0]; + private const int ReadBufferSize = 4096; // TODO: Move buffer size to config private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly IMqttNetLogger _logger; @@ -91,28 +90,35 @@ namespace MQTTnet.Adapter MqttBasePacket packet = null; await ExecuteAndWrapExceptionAsync(async () => { - ReceivedMqttPacket receivedMqttPacket; - if (timeout > TimeSpan.Zero) - { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); - } - else + ReceivedMqttPacket receivedMqttPacket = null; + try { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); - } + if (timeout > TimeSpan.Zero) + { + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); + } + else + { + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); + } - if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } + if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); + } + + packet = PacketSerializer.Deserialize(receivedMqttPacket.Header, receivedMqttPacket.Body); + if (packet == null) + { + throw new MqttProtocolViolationException("Received malformed packet."); + } - packet = PacketSerializer.Deserialize(receivedMqttPacket.Header, receivedMqttPacket.Body); - if (packet == null) + _logger.Trace("RX <<< {0}", packet); + } + finally { - throw new MqttProtocolViolationException("Received malformed packet."); + receivedMqttPacket?.Dispose(); } - - _logger.Trace("RX <<< {0}", packet); }).ConfigureAwait(false); return packet; @@ -120,7 +126,7 @@ namespace MQTTnet.Adapter private static async Task ReceiveAsync(Stream stream, CancellationToken cancellationToken) { - var header = await MqttPacketReader.ReadHeaderFromSourceAsync(stream, cancellationToken).ConfigureAwait(false); + var header = await MqttPacketReader.ReadHeaderAsync(stream, cancellationToken).ConfigureAwait(false); if (header == null) { return null; @@ -128,23 +134,26 @@ namespace MQTTnet.Adapter if (header.BodyLength == 0) { - return new ReceivedMqttPacket(header, EmptyBody); + return new ReceivedMqttPacket(header, new MemoryStream(new byte[0], false)); } - var body = new byte[header.BodyLength]; - - var offset = 0; - do + var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(new byte[header.BodyLength]) : new MemoryStream(); + + var buffer = new byte[ReadBufferSize]; + while (body.Length < header.BodyLength) { - var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset, cancellationToken).ConfigureAwait(false); + var readBytesCount = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + // Check if the client closed the connection before sending the full body. if (readBytesCount == 0) { throw new MqttCommunicationException("Connection closed while reading remaining packet body."); } - offset += readBytesCount; - } while (offset < header.BodyLength); + body.Write(buffer, 0, readBytesCount); + } + + body.Seek(0L, SeekOrigin.Begin); return new ReceivedMqttPacket(header, body); } diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs b/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs index c92f9d0..a44fb54 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs @@ -1,11 +1,12 @@ using System; +using System.IO; using MQTTnet.Packets; namespace MQTTnet.Adapter { - public class ReceivedMqttPacket + public sealed class ReceivedMqttPacket : IDisposable { - public ReceivedMqttPacket(MqttPacketHeader header, byte[] body) + public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body) { Header = header ?? throw new ArgumentNullException(nameof(header)); Body = body ?? throw new ArgumentNullException(nameof(body)); @@ -13,6 +14,11 @@ namespace MQTTnet.Adapter public MqttPacketHeader Header { get; } - public byte[] Body { get; } + public MemoryStream Body { get; } + + public void Dispose() + { + Body?.Dispose(); + } } } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs index 2068641..6577b0a 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.IO; using MQTTnet.Packets; namespace MQTTnet.Serializer @@ -10,6 +11,6 @@ namespace MQTTnet.Serializer ICollection> Serialize(MqttBasePacket mqttPacket); - MqttBasePacket Deserialize(MqttPacketHeader header, byte[] body); + MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs index b54085f..9b793cc 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs @@ -22,7 +22,7 @@ namespace MQTTnet.Serializer public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; - public static async Task ReadHeaderFromSourceAsync(Stream stream, CancellationToken cancellationToken) + public static async Task ReadHeaderAsync(Stream stream, CancellationToken cancellationToken) { // Wait for the next package which starts with the header. At this point there will probably // some large delay and thus the thread should be put back to the pool (await). So ReadByte() @@ -36,7 +36,7 @@ namespace MQTTnet.Serializer var fixedHeader = buffer[0]; var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); - var bodyLength = await ReadBodyLengthFromSourceAsync(stream, cancellationToken).ConfigureAwait(false); + var bodyLength = await ReadBodyLengthAsync(stream, cancellationToken).ConfigureAwait(false); return new MqttPacketHeader { @@ -84,7 +84,7 @@ namespace MQTTnet.Serializer return ReadBytes(_header.BodyLength - (int)BaseStream.Position); } - private static async Task ReadBodyLengthFromSourceAsync(Stream stream, CancellationToken cancellationToken) + private static async Task ReadBodyLengthAsync(Stream stream, CancellationToken cancellationToken) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 1; diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 88414d5..6491e28 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -42,13 +42,12 @@ namespace MQTTnet.Serializer } } - public MqttBasePacket Deserialize(MqttPacketHeader header, byte[] body) + public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) { if (header == null) throw new ArgumentNullException(nameof(header)); if (body == null) throw new ArgumentNullException(nameof(body)); - using (var bodyStream = new MemoryStream(body)) - using (var reader = new MqttPacketReader(header, bodyStream)) + using (var reader = new MqttPacketReader(header, body)) { return Deserialize(header, reader); } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index 4d0e9b0..4d28035 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -12,7 +12,7 @@ namespace MQTTnet.Core.Tests public void MqttPacketReader_EmptyStream() { var memStream = new MemoryStream(); - var header = MqttPacketReader.ReadHeaderFromSourceAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); + var header = MqttPacketReader.ReadHeaderAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); Assert.IsNull(header); } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 5a75490..072e1a9 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -405,11 +405,11 @@ namespace MQTTnet.Core.Tests using (var headerStream = new MemoryStream(Join(buffer1))) { - var header = MqttPacketReader.ReadHeaderFromSourceAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); + var header = MqttPacketReader.ReadHeaderAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength)) { - var deserializedPacket = serializer.Deserialize(header, bodyStream.ToArray()); + var deserializedPacket = serializer.Deserialize(header, bodyStream); var buffer2 = serializer.Serialize(deserializedPacket); Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2)));