@@ -17,8 +17,7 @@ namespace MQTTnet.Adapter | |||
public class MqttChannelAdapter : IMqttChannelAdapter | |||
{ | |||
private const uint ErrorOperationAborted = 0x800703E3; | |||
private static readonly byte[] EmptyBody = new byte[0]; | |||
private const int ReadBufferSize = 4096; // TODO: Move buffer size to config | |||
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); | |||
private readonly IMqttNetLogger _logger; | |||
@@ -91,28 +90,35 @@ namespace MQTTnet.Adapter | |||
MqttBasePacket packet = null; | |||
await ExecuteAndWrapExceptionAsync(async () => | |||
{ | |||
ReceivedMqttPacket receivedMqttPacket; | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); | |||
} | |||
else | |||
ReceivedMqttPacket receivedMqttPacket = null; | |||
try | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); | |||
} | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); | |||
} | |||
if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) | |||
{ | |||
throw new TaskCanceledException(); | |||
} | |||
if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) | |||
{ | |||
throw new TaskCanceledException(); | |||
} | |||
packet = PacketSerializer.Deserialize(receivedMqttPacket.Header, receivedMqttPacket.Body); | |||
if (packet == null) | |||
{ | |||
throw new MqttProtocolViolationException("Received malformed packet."); | |||
} | |||
packet = PacketSerializer.Deserialize(receivedMqttPacket.Header, receivedMqttPacket.Body); | |||
if (packet == null) | |||
_logger.Trace<MqttChannelAdapter>("RX <<< {0}", packet); | |||
} | |||
finally | |||
{ | |||
throw new MqttProtocolViolationException("Received malformed packet."); | |||
receivedMqttPacket?.Dispose(); | |||
} | |||
_logger.Trace<MqttChannelAdapter>("RX <<< {0}", packet); | |||
}).ConfigureAwait(false); | |||
return packet; | |||
@@ -120,7 +126,7 @@ namespace MQTTnet.Adapter | |||
private static async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream, CancellationToken cancellationToken) | |||
{ | |||
var header = await MqttPacketReader.ReadHeaderFromSourceAsync(stream, cancellationToken).ConfigureAwait(false); | |||
var header = await MqttPacketReader.ReadHeaderAsync(stream, cancellationToken).ConfigureAwait(false); | |||
if (header == null) | |||
{ | |||
return null; | |||
@@ -128,23 +134,26 @@ namespace MQTTnet.Adapter | |||
if (header.BodyLength == 0) | |||
{ | |||
return new ReceivedMqttPacket(header, EmptyBody); | |||
return new ReceivedMqttPacket(header, new MemoryStream(new byte[0], false)); | |||
} | |||
var body = new byte[header.BodyLength]; | |||
var offset = 0; | |||
do | |||
var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(new byte[header.BodyLength]) : new MemoryStream(); | |||
var buffer = new byte[ReadBufferSize]; | |||
while (body.Length < header.BodyLength) | |||
{ | |||
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset, cancellationToken).ConfigureAwait(false); | |||
var readBytesCount = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); | |||
// Check if the client closed the connection before sending the full body. | |||
if (readBytesCount == 0) | |||
{ | |||
throw new MqttCommunicationException("Connection closed while reading remaining packet body."); | |||
} | |||
offset += readBytesCount; | |||
} while (offset < header.BodyLength); | |||
body.Write(buffer, 0, readBytesCount); | |||
} | |||
body.Seek(0L, SeekOrigin.Begin); | |||
return new ReceivedMqttPacket(header, body); | |||
} | |||
@@ -1,11 +1,12 @@ | |||
using System; | |||
using System.IO; | |||
using MQTTnet.Packets; | |||
namespace MQTTnet.Adapter | |||
{ | |||
public class ReceivedMqttPacket | |||
public sealed class ReceivedMqttPacket : IDisposable | |||
{ | |||
public ReceivedMqttPacket(MqttPacketHeader header, byte[] body) | |||
public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body) | |||
{ | |||
Header = header ?? throw new ArgumentNullException(nameof(header)); | |||
Body = body ?? throw new ArgumentNullException(nameof(body)); | |||
@@ -13,6 +14,11 @@ namespace MQTTnet.Adapter | |||
public MqttPacketHeader Header { get; } | |||
public byte[] Body { get; } | |||
public MemoryStream Body { get; } | |||
public void Dispose() | |||
{ | |||
Body?.Dispose(); | |||
} | |||
} | |||
} |
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using MQTTnet.Packets; | |||
namespace MQTTnet.Serializer | |||
@@ -10,6 +11,6 @@ namespace MQTTnet.Serializer | |||
ICollection<ArraySegment<byte>> Serialize(MqttBasePacket mqttPacket); | |||
MqttBasePacket Deserialize(MqttPacketHeader header, byte[] body); | |||
MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); | |||
} | |||
} |
@@ -22,7 +22,7 @@ namespace MQTTnet.Serializer | |||
public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; | |||
public static async Task<MqttPacketHeader> ReadHeaderFromSourceAsync(Stream stream, CancellationToken cancellationToken) | |||
public static async Task<MqttPacketHeader> ReadHeaderAsync(Stream stream, CancellationToken cancellationToken) | |||
{ | |||
// Wait for the next package which starts with the header. At this point there will probably | |||
// some large delay and thus the thread should be put back to the pool (await). So ReadByte() | |||
@@ -36,7 +36,7 @@ namespace MQTTnet.Serializer | |||
var fixedHeader = buffer[0]; | |||
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); | |||
var bodyLength = await ReadBodyLengthFromSourceAsync(stream, cancellationToken).ConfigureAwait(false); | |||
var bodyLength = await ReadBodyLengthAsync(stream, cancellationToken).ConfigureAwait(false); | |||
return new MqttPacketHeader | |||
{ | |||
@@ -84,7 +84,7 @@ namespace MQTTnet.Serializer | |||
return ReadBytes(_header.BodyLength - (int)BaseStream.Position); | |||
} | |||
private static async Task<int> ReadBodyLengthFromSourceAsync(Stream stream, CancellationToken cancellationToken) | |||
private static async Task<int> ReadBodyLengthAsync(Stream stream, CancellationToken cancellationToken) | |||
{ | |||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | |||
var multiplier = 1; | |||
@@ -42,13 +42,12 @@ namespace MQTTnet.Serializer | |||
} | |||
} | |||
public MqttBasePacket Deserialize(MqttPacketHeader header, byte[] body) | |||
public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) | |||
{ | |||
if (header == null) throw new ArgumentNullException(nameof(header)); | |||
if (body == null) throw new ArgumentNullException(nameof(body)); | |||
using (var bodyStream = new MemoryStream(body)) | |||
using (var reader = new MqttPacketReader(header, bodyStream)) | |||
using (var reader = new MqttPacketReader(header, body)) | |||
{ | |||
return Deserialize(header, reader); | |||
} | |||
@@ -12,7 +12,7 @@ namespace MQTTnet.Core.Tests | |||
public void MqttPacketReader_EmptyStream() | |||
{ | |||
var memStream = new MemoryStream(); | |||
var header = MqttPacketReader.ReadHeaderFromSourceAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); | |||
var header = MqttPacketReader.ReadHeaderAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); | |||
Assert.IsNull(header); | |||
} | |||
@@ -405,11 +405,11 @@ namespace MQTTnet.Core.Tests | |||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||
{ | |||
var header = MqttPacketReader.ReadHeaderFromSourceAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); | |||
var header = MqttPacketReader.ReadHeaderAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength)) | |||
{ | |||
var deserializedPacket = serializer.Deserialize(header, bodyStream.ToArray()); | |||
var deserializedPacket = serializer.Deserialize(header, bodyStream); | |||
var buffer2 = serializer.Serialize(deserializedPacket); | |||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2))); | |||