Bläddra i källkod

Covnert the body buffer to a memory stream which will grow automatically.

release/3.x.x
Christian 6 år sedan
förälder
incheckning
1415e4b878
7 ändrade filer med 56 tillägg och 41 borttagningar
  1. +37
    -28
      Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs
  2. +9
    -3
      Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs
  3. +2
    -1
      Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs
  4. +3
    -3
      Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs
  5. +2
    -3
      Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs
  6. +1
    -1
      Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs
  7. +2
    -2
      Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs

+ 37
- 28
Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs Visa fil

@@ -17,8 +17,7 @@ namespace MQTTnet.Adapter
public class MqttChannelAdapter : IMqttChannelAdapter
{
private const uint ErrorOperationAborted = 0x800703E3;

private static readonly byte[] EmptyBody = new byte[0];
private const int ReadBufferSize = 4096; // TODO: Move buffer size to config

private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private readonly IMqttNetLogger _logger;
@@ -91,28 +90,35 @@ namespace MQTTnet.Adapter
MqttBasePacket packet = null;
await ExecuteAndWrapExceptionAsync(async () =>
{
ReceivedMqttPacket receivedMqttPacket;
if (timeout > TimeSpan.Zero)
{
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false);
}
else
ReceivedMqttPacket receivedMqttPacket = null;
try
{
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false);
}
if (timeout > TimeSpan.Zero)
{
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false);
}
else
{
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false);
}

if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}
if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}

packet = PacketSerializer.Deserialize(receivedMqttPacket.Header, receivedMqttPacket.Body);
if (packet == null)
{
throw new MqttProtocolViolationException("Received malformed packet.");
}

packet = PacketSerializer.Deserialize(receivedMqttPacket.Header, receivedMqttPacket.Body);
if (packet == null)
_logger.Trace<MqttChannelAdapter>("RX <<< {0}", packet);
}
finally
{
throw new MqttProtocolViolationException("Received malformed packet.");
receivedMqttPacket?.Dispose();
}

_logger.Trace<MqttChannelAdapter>("RX <<< {0}", packet);
}).ConfigureAwait(false);

return packet;
@@ -120,7 +126,7 @@ namespace MQTTnet.Adapter

private static async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream, CancellationToken cancellationToken)
{
var header = await MqttPacketReader.ReadHeaderFromSourceAsync(stream, cancellationToken).ConfigureAwait(false);
var header = await MqttPacketReader.ReadHeaderAsync(stream, cancellationToken).ConfigureAwait(false);
if (header == null)
{
return null;
@@ -128,23 +134,26 @@ namespace MQTTnet.Adapter

if (header.BodyLength == 0)
{
return new ReceivedMqttPacket(header, EmptyBody);
return new ReceivedMqttPacket(header, new MemoryStream(new byte[0], false));
}

var body = new byte[header.BodyLength];
var offset = 0;
do
var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(new byte[header.BodyLength]) : new MemoryStream();
var buffer = new byte[ReadBufferSize];
while (body.Length < header.BodyLength)
{
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset, cancellationToken).ConfigureAwait(false);
var readBytesCount = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
// Check if the client closed the connection before sending the full body.
if (readBytesCount == 0)
{
throw new MqttCommunicationException("Connection closed while reading remaining packet body.");
}

offset += readBytesCount;
} while (offset < header.BodyLength);
body.Write(buffer, 0, readBytesCount);
}

body.Seek(0L, SeekOrigin.Begin);
return new ReceivedMqttPacket(header, body);
}


+ 9
- 3
Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs Visa fil

@@ -1,11 +1,12 @@
using System;
using System.IO;
using MQTTnet.Packets;

namespace MQTTnet.Adapter
{
public class ReceivedMqttPacket
public sealed class ReceivedMqttPacket : IDisposable
{
public ReceivedMqttPacket(MqttPacketHeader header, byte[] body)
public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body)
{
Header = header ?? throw new ArgumentNullException(nameof(header));
Body = body ?? throw new ArgumentNullException(nameof(body));
@@ -13,6 +14,11 @@ namespace MQTTnet.Adapter

public MqttPacketHeader Header { get; }

public byte[] Body { get; }
public MemoryStream Body { get; }

public void Dispose()
{
Body?.Dispose();
}
}
}

+ 2
- 1
Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs Visa fil

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using MQTTnet.Packets;

namespace MQTTnet.Serializer
@@ -10,6 +11,6 @@ namespace MQTTnet.Serializer

ICollection<ArraySegment<byte>> Serialize(MqttBasePacket mqttPacket);

MqttBasePacket Deserialize(MqttPacketHeader header, byte[] body);
MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body);
}
}

+ 3
- 3
Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs Visa fil

@@ -22,7 +22,7 @@ namespace MQTTnet.Serializer

public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength;

public static async Task<MqttPacketHeader> ReadHeaderFromSourceAsync(Stream stream, CancellationToken cancellationToken)
public static async Task<MqttPacketHeader> ReadHeaderAsync(Stream stream, CancellationToken cancellationToken)
{
// Wait for the next package which starts with the header. At this point there will probably
// some large delay and thus the thread should be put back to the pool (await). So ReadByte()
@@ -36,7 +36,7 @@ namespace MQTTnet.Serializer

var fixedHeader = buffer[0];
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4);
var bodyLength = await ReadBodyLengthFromSourceAsync(stream, cancellationToken).ConfigureAwait(false);
var bodyLength = await ReadBodyLengthAsync(stream, cancellationToken).ConfigureAwait(false);

return new MqttPacketHeader
{
@@ -84,7 +84,7 @@ namespace MQTTnet.Serializer
return ReadBytes(_header.BodyLength - (int)BaseStream.Position);
}

private static async Task<int> ReadBodyLengthFromSourceAsync(Stream stream, CancellationToken cancellationToken)
private static async Task<int> ReadBodyLengthAsync(Stream stream, CancellationToken cancellationToken)
{
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html.
var multiplier = 1;


+ 2
- 3
Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs Visa fil

@@ -42,13 +42,12 @@ namespace MQTTnet.Serializer
}
}

public MqttBasePacket Deserialize(MqttPacketHeader header, byte[] body)
public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body)
{
if (header == null) throw new ArgumentNullException(nameof(header));
if (body == null) throw new ArgumentNullException(nameof(body));

using (var bodyStream = new MemoryStream(body))
using (var reader = new MqttPacketReader(header, bodyStream))
using (var reader = new MqttPacketReader(header, body))
{
return Deserialize(header, reader);
}


+ 1
- 1
Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs Visa fil

@@ -12,7 +12,7 @@ namespace MQTTnet.Core.Tests
public void MqttPacketReader_EmptyStream()
{
var memStream = new MemoryStream();
var header = MqttPacketReader.ReadHeaderFromSourceAsync(memStream, CancellationToken.None).GetAwaiter().GetResult();
var header = MqttPacketReader.ReadHeaderAsync(memStream, CancellationToken.None).GetAwaiter().GetResult();

Assert.IsNull(header);
}


+ 2
- 2
Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs Visa fil

@@ -405,11 +405,11 @@ namespace MQTTnet.Core.Tests

using (var headerStream = new MemoryStream(Join(buffer1)))
{
var header = MqttPacketReader.ReadHeaderFromSourceAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult();
var header = MqttPacketReader.ReadHeaderAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult();

using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength))
{
var deserializedPacket = serializer.Deserialize(header, bodyStream.ToArray());
var deserializedPacket = serializer.Deserialize(header, bodyStream);
var buffer2 = serializer.Serialize(deserializedPacket);

Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2)));


Laddar…
Avbryt
Spara