@@ -1,6 +1,7 @@
using System;
using System.Buffers;
using System.IO;
using MQTTnet.Adapter;
using MQTTnet.Exceptions;
using MQTTnet.Packets;
using MQTTnet.Protocol;
@@ -10,37 +11,23 @@ namespace MQTTnet.AspNetCore
{
public static class ReaderExtensions
{
public static MqttPacketHeader ReadHeader(this ref ReadOnlySequence<byte> input)
{
if (input.Length < 2)
{
return null;
}
var fixedHeader = input.First.Span[0];
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4);
var bodyLength = ReadBodyLength(ref input);
return new MqttPacketHeader
{
FixedHeader = fixedHeader,
ControlPacketType = controlPacketType,
BodyLength = bodyLength
};
}
private static int ReadBodyLength(ref ReadOnlySequence<byte> input)
private static bool TryReadBodyLength(ref ReadOnlySequence<byte> input, out int result)
{
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html.
var multiplier = 1;
var value = 0;
byte encodedByte;
var index = 1;
result = 0;
var temp = input.Slice(0, Math.Min(5, input.Length)).GetArray();
do
{
if (index == temp.Length)
{
return false;
}
encodedByte = temp[index];
index++;
@@ -55,7 +42,8 @@ namespace MQTTnet.AspNetCore
input = input.Slice(index);
return value;
result = value;
return true;
}
@@ -75,17 +63,22 @@ namespace MQTTnet.AspNetCore
{
packet = null;
var copy = input;
var header = copy.ReadHeader();
if (header == null || copy.Length < header.BodyLength)
if (copy.Length < 2)
{
return false;
}
var fixedheader = copy.First.Span[0];
if (!TryReadBodyLength(ref copy, out var bodyLength))
{
return false;
}
input = copy.Slice(header.BodyLength);
var bodySlice = copy.Slice(0, header.B odyLength);
input = copy.Slice(b odyLength);
var bodySlice = copy.Slice(0, b odyLength);
using (var body = new MemoryStream(bodySlice.GetArray()))
{
packet = serializer.Deserialize(header, body);
packet = serializer.Deserialize(new ReceivedMqttPacket(fixed header, body) );
return true;
}
}
@@ -93,19 +86,24 @@ namespace MQTTnet.AspNetCore
public static bool TryDeserialize(this IMqttPacketSerializer serializer, in ReadOnlySequence<byte> input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed)
{
packet = null;
consumed = input.Start;
observed = input.End;
var copy = input;
var header = copy.ReadHeader();
if (header == null || copy.Length < header.BodyLength)
if (copy.Length < 2)
{
return false;
}
var fixedheader = copy.First.Span[0];
if (!TryReadBodyLength(ref copy, out var bodyLength))
{
consumed = input.Start;
observed = input.End;
return false;
}
var bodySlice = copy.Slice(0, header.B odyLength);
var bodySlice = copy.Slice(0, b odyLength);
using (var body = new MemoryStream(bodySlice.GetArray()))
{
packet = serializer.Deserialize(header, body);
packet = serializer.Deserialize(new ReceivedMqttPacket(fixed header, body) );
consumed = bodySlice.End;
observed = bodySlice.End;
return true;