@@ -139,37 +139,36 @@ namespace MQTTnet.Adapter | |||
{ | |||
ReadingPacketStarted?.Invoke(this, EventArgs.Empty); | |||
var bodyLength = await MqttPacketReader.ReadBodyLengthAsync(channel, cancellationToken).ConfigureAwait(false); | |||
if (bodyLength == 0) | |||
if (fixedHeader.RemainingLength == 0) | |||
{ | |||
return new ReceivedMqttPacket(fixedHeader, null); | |||
return new ReceivedMqttPacket(fixedHeader.Flags, null); | |||
} | |||
var body = new MemoryStream(bodyLength); | |||
var body = new MemoryStream(fixedHeader.RemainingLength); | |||
var buffer = new byte[Math.Min(ReadBufferSize, bodyLength)]; | |||
while (body.Length < bodyLength) | |||
var buffer = new byte[Math.Min(ReadBufferSize, fixedHeader.RemainingLength)]; | |||
while (body.Length < fixedHeader.RemainingLength) | |||
{ | |||
var bytesLeft = bodyLength - (int)body.Length; | |||
var bytesLeft = fixedHeader.RemainingLength - (int)body.Length; | |||
if (bytesLeft > buffer.Length) | |||
{ | |||
bytesLeft = buffer.Length; | |||
} | |||
var readBytesCount = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); | |||
if (readBytesCount <= 0) | |||
var readBytes = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); | |||
if (readBytes <= 0) | |||
{ | |||
ExceptionHelper.ThrowGracefulSocketClose(); | |||
} | |||
// Here is no need to await because internally only an array is used and no real I/O operation is made. | |||
// Using async here will only generate overhead. | |||
body.Write(buffer, 0, readBytesCount); | |||
body.Write(buffer, 0, readBytes); | |||
} | |||
body.Seek(0L, SeekOrigin.Begin); | |||
return new ReceivedMqttPacket(fixedHeader, body); | |||
return new ReceivedMqttPacket(fixedHeader.Flags, body); | |||
} | |||
finally | |||
{ | |||
@@ -13,7 +13,7 @@ namespace MQTTnet.Client | |||
{ | |||
foreach (var awaiter in _awaiters) | |||
{ | |||
Task.Run(() => awaiter.Value.SetException(exception)); // Task.Run fixes a dead lock. Without this the client only receives one message. | |||
Task.Run(() => awaiter.Value.TrySetException(exception)); // Task.Run fixes a dead lock. Without this the client only receives one message. | |||
} | |||
} | |||
@@ -32,8 +32,7 @@ namespace MQTTnet.Client | |||
if (_awaiters.TryRemove(key, out var awaiter)) | |||
{ | |||
awaiter.SetResult(packet); | |||
Task.Run(() => awaiter.SetResult(packet)); // Task.Run fixes a dead lock. Without this the client only receives one message. | |||
Task.Run(() => awaiter.TrySetResult(packet)); // Task.Run fixes a dead lock. Without this the client only receives one message. | |||
return; | |||
} | |||
@@ -0,0 +1,15 @@ | |||
namespace MQTTnet.Serializer | |||
{ | |||
public struct MqttFixedHeader | |||
{ | |||
public MqttFixedHeader(byte flags, int remainingLength) | |||
{ | |||
Flags = flags; | |||
RemainingLength = remainingLength; | |||
} | |||
public byte Flags { get; } | |||
public int RemainingLength { get; } | |||
} | |||
} |
@@ -11,19 +11,32 @@ namespace MQTTnet.Serializer | |||
{ | |||
public static class MqttPacketReader | |||
{ | |||
public static async Task<byte> ReadFixedHeaderAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
public static async Task<MqttFixedHeader> ReadFixedHeaderAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
{ | |||
// Wait for the next package which starts with the header. At this point there will probably | |||
// some large delay and thus the thread should be put back to the pool (await). So ReadByte() | |||
// is not an option here. | |||
var buffer = new byte[1]; | |||
var readCount = await channel.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); | |||
if (readCount <= 0) | |||
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | |||
// So in all cases at least 2 bytes must be read for a complete MQTT packet. | |||
var buffer = new byte[2]; | |||
var totalBytesRead = 0; | |||
while (totalBytesRead < buffer.Length) | |||
{ | |||
ExceptionHelper.ThrowGracefulSocketClose(); | |||
var bytesRead = await channel.ReadAsync(buffer, 0, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); | |||
if (bytesRead <= 0) | |||
{ | |||
ExceptionHelper.ThrowGracefulSocketClose(); | |||
} | |||
totalBytesRead += bytesRead; | |||
} | |||
return buffer[0]; | |||
var hasRemainingLength = buffer[1] != 0; | |||
if (!hasRemainingLength) | |||
{ | |||
return new MqttFixedHeader(buffer[0], 0); | |||
} | |||
var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken); | |||
return new MqttFixedHeader(buffer[0], bodyLength); | |||
} | |||
public static ushort ReadUInt16(this Stream stream) | |||
@@ -64,21 +77,29 @@ namespace MQTTnet.Serializer | |||
return stream.ReadBytes((int)(stream.Length - stream.Position)); | |||
} | |||
public static async Task<int> ReadBodyLengthAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
private static byte[] ReadBytes(this Stream stream, int count) | |||
{ | |||
var buffer = new byte[count]; | |||
var readBytes = stream.Read(buffer, 0, count); | |||
if (readBytes != count) | |||
{ | |||
throw new InvalidOperationException($"Unable to read {count} bytes from the stream."); | |||
} | |||
return buffer; | |||
} | |||
private static async Task<int> ReadBodyLengthAsync(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) | |||
{ | |||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | |||
var multiplier = 1; | |||
var value = 0; | |||
int encodedByte; | |||
var value = (byte)(initialEncodedByte & 127) * multiplier; | |||
int encodedByte = initialEncodedByte; | |||
var buffer = new byte[1]; | |||
do | |||
while ((encodedByte & 128) != 0) | |||
{ | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
throw new TaskCanceledException(); | |||
} | |||
var readCount = await channel.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); | |||
if (readCount <= 0) | |||
{ | |||
@@ -94,16 +115,9 @@ namespace MQTTnet.Serializer | |||
} | |||
multiplier *= 128; | |||
} while ((encodedByte & 128) != 0); | |||
} | |||
return value; | |||
} | |||
private static byte[] ReadBytes(this Stream stream, int count) | |||
{ | |||
var buffer = new byte[count]; | |||
stream.Read(buffer, 0, count); | |||
return buffer; | |||
} | |||
} | |||
} |
@@ -10,6 +10,8 @@ namespace MQTTnet.Serializer | |||
{ | |||
public class MqttPacketSerializer : IMqttPacketSerializer | |||
{ | |||
private const int FixedHeaderSize = 1; | |||
public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; | |||
public ArraySegment<byte> Serialize(MqttBasePacket packet) | |||
@@ -18,27 +20,33 @@ namespace MQTTnet.Serializer | |||
using (var stream = new MemoryStream(128)) | |||
{ | |||
// Leave enough head space for max header size (fixed + 4 variable remaining length) | |||
stream.Position = 5; | |||
// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) | |||
stream.Seek(5, SeekOrigin.Begin); | |||
var fixedHeader = SerializePacket(packet, stream); | |||
var remainingLength = (int)stream.Length - 5; | |||
stream.Position = 1; | |||
var remainingLength = MqttPacketWriter.EncodeRemainingLength((int)stream.Length - 5, stream); | |||
var remainingLengthSize = MqttPacketWriter.WriteRemainingLength(remainingLength, stream); | |||
var headerSize = remainingLength + 1; | |||
var headerSize = FixedHeaderSize + remainingLengthSize; | |||
var headerOffset = 5 - headerSize; | |||
// Position cursor on correct offset on beginining of array (has leading 0x0) | |||
stream.Position = headerOffset; | |||
stream.Seek(headerOffset, SeekOrigin.Begin); | |||
stream.WriteByte(fixedHeader); | |||
#if NET461 || NET452 || NETSTANDARD2_0 | |||
var buffer = stream.GetBuffer(); | |||
return new ArraySegment<byte>(buffer, headerOffset, (int)stream.Length - headerOffset); | |||
#else | |||
if (stream.TryGetBuffer(out var segment)) | |||
{ | |||
return new ArraySegment<byte>(segment.Array, headerOffset, segment.Count - headerOffset); | |||
} | |||
var buffer = stream.ToArray(); | |||
#endif | |||
return new ArraySegment<byte>(buffer, headerOffset, (int)stream.Length - headerOffset); | |||
return new ArraySegment<byte>(buffer, headerOffset, buffer.Length - headerOffset); | |||
#endif | |||
} | |||
} | |||
@@ -41,19 +41,18 @@ namespace MQTTnet.Serializer | |||
stream.Write(value, 0, length); | |||
} | |||
public static int EncodeRemainingLength(int length, MemoryStream stream) | |||
public static int WriteRemainingLength(int length, MemoryStream stream) | |||
{ | |||
// write the encoded remaining length right aligned on the 4 byte buffer | |||
if (length <= 0) | |||
{ | |||
stream.Seek(3, SeekOrigin.Current); | |||
stream.Seek(4, SeekOrigin.Begin); | |||
stream.WriteByte(0); | |||
return 1; | |||
} | |||
var buffer = new byte[4]; | |||
var offset = 0; | |||
var remainingLengthSize = 0; | |||
// Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. | |||
var x = length; | |||
@@ -66,15 +65,15 @@ namespace MQTTnet.Serializer | |||
encodedByte = encodedByte | 128; | |||
} | |||
buffer[offset] = (byte)encodedByte; | |||
buffer[remainingLengthSize] = (byte)encodedByte; | |||
offset++; | |||
remainingLengthSize++; | |||
} while (x > 0); | |||
stream.Seek(4 - offset, SeekOrigin.Current); | |||
stream.Write(buffer, 0, offset); | |||
stream.Seek(5 - remainingLengthSize, SeekOrigin.Begin); | |||
stream.Write(buffer, 0, remainingLengthSize); | |||
return offset; | |||
return remainingLengthSize; | |||
} | |||
} | |||
} |
@@ -51,12 +51,11 @@ namespace MQTTnet.Benchmarks | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(headerStream), CancellationToken.None).GetAwaiter().GetResult(); | |||
var bodyLength = MqttPacketReader.ReadBodyLengthAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(_serializedPacket), (int)headerStream.Position, bodyLength)) | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(_serializedPacket), (int)headerStream.Position, header.RemainingLength)) | |||
{ | |||
_serializer.Deserialize(new ReceivedMqttPacket((byte)header, bodyStream)); | |||
_serializer.Deserialize(new ReceivedMqttPacket(header.Flags, bodyStream)); | |||
} | |||
} | |||
} | |||
@@ -2,6 +2,7 @@ | |||
using System.Threading; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Core.Internal; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Serializer; | |||
namespace MQTTnet.Core.Tests | |||
@@ -10,11 +11,10 @@ namespace MQTTnet.Core.Tests | |||
public class MqttPacketReaderTests | |||
{ | |||
[TestMethod] | |||
[ExpectedException(typeof(MqttCommunicationException))] | |||
public void MqttPacketReader_EmptyStream() | |||
{ | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); | |||
Assert.AreEqual(-1, header); | |||
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); | |||
} | |||
} | |||
} |
@@ -419,11 +419,10 @@ namespace MQTTnet.Core.Tests | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||
var bodyLength = MqttPacketReader.ReadBodyLengthAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, bodyLength)) | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | |||
{ | |||
var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket((byte)header, bodyStream)); | |||
var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header.Flags, bodyStream)); | |||
var buffer2 = serializer.Serialize(deserializedPacket); | |||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2))); | |||