@@ -49,7 +49,7 @@ namespace MQTTnet.AspNetCore | |||
var formatter = new MqttPacketFormatterAdapter(writer); | |||
var channel = new MqttWebSocketChannel(webSocket, endpoint, isSecureConnection, clientCertificate); | |||
using (var channelAdapter = new MqttChannelAdapter(channel, formatter, _rootLogger)) | |||
using (var channelAdapter = new MqttChannelAdapter(channel, formatter, null, _rootLogger)) | |||
{ | |||
await clientHandler(channelAdapter).ConfigureAwait(false); | |||
} | |||
@@ -7,7 +7,7 @@ using System; | |||
namespace MQTTnet.Extensions.WebSocket4Net | |||
{ | |||
public class WebSocket4NetMqttClientAdapterFactory : IMqttClientAdapterFactory | |||
public sealed class WebSocket4NetMqttClientAdapterFactory : IMqttClientAdapterFactory | |||
{ | |||
readonly IMqttNetLogger _logger; | |||
@@ -23,14 +23,22 @@ namespace MQTTnet.Extensions.WebSocket4Net | |||
switch (options.ChannelOptions) | |||
{ | |||
case MqttClientTcpOptions _: | |||
{ | |||
return new MqttChannelAdapter(new MqttTcpChannel(options), new MqttPacketFormatterAdapter(options.ProtocolVersion), _logger); | |||
} | |||
{ | |||
return new MqttChannelAdapter( | |||
new MqttTcpChannel(options), | |||
new MqttPacketFormatterAdapter(options.ProtocolVersion), | |||
options.PacketInspector, | |||
_logger); | |||
} | |||
case MqttClientWebSocketOptions webSocketOptions: | |||
{ | |||
return new MqttChannelAdapter(new WebSocket4NetMqttChannel(options, webSocketOptions), new MqttPacketFormatterAdapter(options.ProtocolVersion), _logger); | |||
} | |||
{ | |||
return new MqttChannelAdapter( | |||
new WebSocket4NetMqttChannel(options, webSocketOptions), | |||
new MqttPacketFormatterAdapter(options.ProtocolVersion), | |||
options.PacketInspector, | |||
_logger); | |||
} | |||
default: | |||
{ | |||
@@ -1,7 +1,7 @@ | |||
<Project Sdk="Microsoft.NET.Sdk.Web"> | |||
<PropertyGroup> | |||
<TargetFrameworks>netcoreapp3.1;net5.0</TargetFrameworks> | |||
<TargetFrameworks>net5.0</TargetFrameworks> | |||
<AspNetCoreHostingModel>InProcess</AspNetCoreHostingModel> | |||
<AssemblyName>MQTTnet.Server</AssemblyName> | |||
<RootNamespace>MQTTnet.Server</RootNamespace> | |||
@@ -11,6 +11,7 @@ using System.Runtime.InteropServices; | |||
using System.Security.Cryptography.X509Certificates; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Diagnostics.PacketInspection; | |||
namespace MQTTnet.Adapter | |||
{ | |||
@@ -19,26 +20,26 @@ namespace MQTTnet.Adapter | |||
const uint ErrorOperationAborted = 0x800703E3; | |||
const int ReadBufferSize = 4096; | |||
readonly byte[] _singleByteBuffer = new byte[1]; | |||
readonly byte[] _fixedHeaderBuffer = new byte[2]; | |||
readonly MqttPacketInspectorHandler _packetInspectorHandler; | |||
readonly IMqttNetScopedLogger _logger; | |||
readonly IMqttChannel _channel; | |||
readonly MqttPacketReader _packetReader; | |||
readonly byte[] _fixedHeaderBuffer = new byte[2]; | |||
readonly AsyncLock _syncRoot = new AsyncLock(); | |||
long _bytesReceived; | |||
long _bytesSent; | |||
public MqttChannelAdapter(IMqttChannel channel, MqttPacketFormatterAdapter packetFormatterAdapter, IMqttNetLogger logger) | |||
public MqttChannelAdapter(IMqttChannel channel, MqttPacketFormatterAdapter packetFormatterAdapter, IMqttPacketInspector packetInspector, IMqttNetLogger logger) | |||
{ | |||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | |||
_channel = channel ?? throw new ArgumentNullException(nameof(channel)); | |||
PacketFormatterAdapter = packetFormatterAdapter ?? throw new ArgumentNullException(nameof(packetFormatterAdapter)); | |||
_packetReader = new MqttPacketReader(_channel); | |||
_packetInspectorHandler = new MqttPacketInspectorHandler(packetInspector, logger); | |||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | |||
_logger = logger.CreateScopedLogger(nameof(MqttChannelAdapter)); | |||
} | |||
@@ -124,6 +125,7 @@ namespace MQTTnet.Adapter | |||
try | |||
{ | |||
var packetData = PacketFormatterAdapter.Encode(packet); | |||
_packetInspectorHandler.BeginSendPacket(packetData); | |||
await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); | |||
@@ -154,12 +156,16 @@ namespace MQTTnet.Adapter | |||
try | |||
{ | |||
_packetInspectorHandler.BeginReceivePacket(); | |||
var receivedPacket = await ReceiveAsync(cancellationToken).ConfigureAwait(false); | |||
if (receivedPacket == null || cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
_packetInspectorHandler.EndReceivePacket(); | |||
Interlocked.Add(ref _bytesSent, receivedPacket.TotalLength); | |||
if (PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.Unknown) | |||
@@ -215,7 +221,12 @@ namespace MQTTnet.Adapter | |||
async Task<ReceivedMqttPacket> ReceiveAsync(CancellationToken cancellationToken) | |||
{ | |||
var readFixedHeaderResult = await _packetReader.ReadFixedHeaderAsync(_fixedHeaderBuffer, cancellationToken).ConfigureAwait(false); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
var readFixedHeaderResult = await ReadFixedHeaderAsync(cancellationToken).ConfigureAwait(false); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
@@ -234,7 +245,7 @@ namespace MQTTnet.Adapter | |||
var fixedHeader = readFixedHeaderResult.FixedHeader; | |||
if (fixedHeader.RemainingLength == 0) | |||
{ | |||
return new ReceivedMqttPacket(fixedHeader.Flags, null, 2); | |||
return new ReceivedMqttPacket(fixedHeader.Flags, new MqttPacketBodyReader(new byte[0], 0, 0), 2); | |||
} | |||
var bodyLength = fixedHeader.RemainingLength; | |||
@@ -266,6 +277,8 @@ namespace MQTTnet.Adapter | |||
bodyOffset += readBytes; | |||
} while (bodyOffset < bodyLength); | |||
_packetInspectorHandler.FillReceiveBuffer(body); | |||
var bodyReader = new MqttPacketBodyReader(body, 0, bodyLength); | |||
return new ReceivedMqttPacket(fixedHeader.Flags, bodyReader, fixedHeader.TotalLength); | |||
} | |||
@@ -275,11 +288,110 @@ namespace MQTTnet.Adapter | |||
} | |||
} | |||
async Task<ReadFixedHeaderResult> ReadFixedHeaderAsync(CancellationToken cancellationToken) | |||
{ | |||
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | |||
// So in all cases at least 2 bytes must be read for a complete MQTT packet. | |||
var buffer = _fixedHeaderBuffer; | |||
var totalBytesRead = 0; | |||
while (totalBytesRead < buffer.Length) | |||
{ | |||
var bytesRead = await _channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
if (bytesRead == 0) | |||
{ | |||
return new ReadFixedHeaderResult | |||
{ | |||
ConnectionClosed = true | |||
}; | |||
} | |||
totalBytesRead += bytesRead; | |||
} | |||
_packetInspectorHandler.FillReceiveBuffer(buffer); | |||
var hasRemainingLength = buffer[1] != 0; | |||
if (!hasRemainingLength) | |||
{ | |||
return new ReadFixedHeaderResult | |||
{ | |||
FixedHeader = new MqttFixedHeader(buffer[0], 0, totalBytesRead) | |||
}; | |||
} | |||
var bodyLength = await ReadBodyLengthAsync(buffer[1], cancellationToken).ConfigureAwait(false); | |||
if (!bodyLength.HasValue) | |||
{ | |||
return new ReadFixedHeaderResult | |||
{ | |||
ConnectionClosed = true | |||
}; | |||
} | |||
totalBytesRead += bodyLength.Value; | |||
return new ReadFixedHeaderResult | |||
{ | |||
FixedHeader = new MqttFixedHeader(buffer[0], bodyLength.Value, totalBytesRead) | |||
}; | |||
} | |||
async Task<int?> ReadBodyLengthAsync(byte initialEncodedByte, CancellationToken cancellationToken) | |||
{ | |||
var offset = 0; | |||
var multiplier = 128; | |||
var value = initialEncodedByte & 127; | |||
int encodedByte = initialEncodedByte; | |||
while ((encodedByte & 128) != 0) | |||
{ | |||
offset++; | |||
if (offset > 3) | |||
{ | |||
throw new MqttProtocolViolationException("Remaining length is invalid."); | |||
} | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
var readCount = await _channel.ReadAsync(_singleByteBuffer, 0, 1, cancellationToken).ConfigureAwait(false); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
if (readCount == 0) | |||
{ | |||
return null; | |||
} | |||
_packetInspectorHandler.FillReceiveBuffer(_singleByteBuffer); | |||
encodedByte = _singleByteBuffer[0]; | |||
value += (encodedByte & 127) * multiplier; | |||
multiplier *= 128; | |||
} | |||
return value; | |||
} | |||
static bool IsWrappedException(Exception exception) | |||
{ | |||
return exception is OperationCanceledException || | |||
exception is MqttCommunicationTimedOutException || | |||
exception is MqttCommunicationException; | |||
exception is MqttCommunicationException || | |||
exception is MqttProtocolViolationException; | |||
} | |||
static void WrapAndThrowException(Exception exception) | |||
@@ -295,7 +407,7 @@ namespace MQTTnet.Adapter | |||
{ | |||
throw new OperationCanceledException(); | |||
} | |||
if (socketException.SocketErrorCode == SocketError.ConnectionAborted) | |||
{ | |||
throw new MqttCommunicationException(socketException); | |||
@@ -0,0 +1,84 @@ | |||
using System; | |||
using System.IO; | |||
using System.Linq; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Diagnostics.PacketInspection; | |||
namespace MQTTnet.Adapter | |||
{ | |||
public sealed class MqttPacketInspectorHandler | |||
{ | |||
readonly MemoryStream _receivedPacketBuffer; | |||
readonly IMqttPacketInspector _packetInspector; | |||
readonly IMqttNetScopedLogger _logger; | |||
public MqttPacketInspectorHandler(IMqttPacketInspector packetInspector, IMqttNetLogger logger) | |||
{ | |||
_packetInspector = packetInspector; | |||
if (packetInspector != null) | |||
{ | |||
_receivedPacketBuffer = new MemoryStream(); | |||
} | |||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | |||
_logger = logger.CreateScopedLogger(nameof(MqttPacketInspectorHandler)); | |||
} | |||
public void BeginReceivePacket() | |||
{ | |||
_receivedPacketBuffer?.SetLength(0); | |||
} | |||
public void EndReceivePacket() | |||
{ | |||
if (_packetInspector == null) | |||
{ | |||
return; | |||
} | |||
var buffer = _receivedPacketBuffer.ToArray(); | |||
_receivedPacketBuffer.SetLength(0); | |||
InspectPacket(buffer, MqttPacketFlowDirection.Inbound); | |||
} | |||
public void BeginSendPacket(ArraySegment<byte> buffer) | |||
{ | |||
if (_packetInspector == null) | |||
{ | |||
return; | |||
} | |||
// Create a copy of the actual packet so that the inspector gets no access | |||
// to the internal buffers. This is waste of memory but this feature is only | |||
// intended for debugging etc. so that this is OK. | |||
var bufferCopy = buffer.ToArray(); | |||
InspectPacket(bufferCopy, MqttPacketFlowDirection.Outbound); | |||
} | |||
public void FillReceiveBuffer(byte[] buffer) | |||
{ | |||
_receivedPacketBuffer?.Write(buffer, 0, buffer.Length); | |||
} | |||
void InspectPacket(byte[] buffer, MqttPacketFlowDirection direction) | |||
{ | |||
try | |||
{ | |||
var context = new ProcessMqttPacketContext | |||
{ | |||
Buffer = buffer, | |||
Direction = direction | |||
}; | |||
_packetInspector.ProcessMqttPacket(context); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while inspecting packet."); | |||
} | |||
} | |||
} | |||
} |
@@ -1,19 +1,20 @@ | |||
using MQTTnet.Formatter; | |||
using System; | |||
using MQTTnet.Formatter; | |||
namespace MQTTnet.Adapter | |||
{ | |||
public sealed class ReceivedMqttPacket | |||
{ | |||
public ReceivedMqttPacket(byte fixedHeader, IMqttPacketBodyReader body, int totalLength) | |||
public ReceivedMqttPacket(byte fixedHeader, IMqttPacketBodyReader bodyReader, int totalLength) | |||
{ | |||
FixedHeader = fixedHeader; | |||
Body = body; | |||
BodyReader = bodyReader ?? throw new ArgumentNullException(nameof(bodyReader)); | |||
TotalLength = totalLength; | |||
} | |||
public byte FixedHeader { get; } | |||
public IMqttPacketBodyReader Body { get; } | |||
public IMqttPacketBodyReader BodyReader { get; } | |||
public int TotalLength { get; } | |||
} | |||
@@ -3,6 +3,7 @@ using MQTTnet.Formatter; | |||
using MQTTnet.Packets; | |||
using System; | |||
using System.Collections.Generic; | |||
using MQTTnet.Diagnostics.PacketInspection; | |||
namespace MQTTnet.Client.Options | |||
{ | |||
@@ -29,5 +30,7 @@ namespace MQTTnet.Client.Options | |||
uint? SessionExpiryInterval { get; } | |||
ushort? TopicAliasMaximum { get; } | |||
List<MqttUserProperty> UserProperties { get; set; } | |||
IMqttPacketInspector PacketInspector { get; set; } | |||
} | |||
} |
@@ -3,6 +3,7 @@ using MQTTnet.Formatter; | |||
using MQTTnet.Packets; | |||
using System; | |||
using System.Collections.Generic; | |||
using MQTTnet.Diagnostics.PacketInspection; | |||
namespace MQTTnet.Client.Options | |||
{ | |||
@@ -31,5 +32,7 @@ namespace MQTTnet.Client.Options | |||
public uint? SessionExpiryInterval { get; set; } | |||
public ushort? TopicAliasMaximum { get; set; } | |||
public List<MqttUserProperty> UserProperties { get; set; } | |||
public IMqttPacketInspector PacketInspector { get; set; } | |||
} | |||
} |
@@ -5,6 +5,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using MQTTnet.Diagnostics.PacketInspection; | |||
namespace MQTTnet.Client.Options | |||
{ | |||
@@ -256,6 +257,12 @@ namespace MQTTnet.Client.Options | |||
return this; | |||
} | |||
public MqttClientOptionsBuilder WithPacketInspector(IMqttPacketInspector packetInspector) | |||
{ | |||
_options.PacketInspector = packetInspector; | |||
return this; | |||
} | |||
public IMqttClientOptions Build() | |||
{ | |||
if (_tcpOptions == null && _webSocketOptions == null) | |||
@@ -0,0 +1,7 @@ | |||
namespace MQTTnet.Diagnostics.PacketInspection | |||
{ | |||
public interface IMqttPacketInspector | |||
{ | |||
void ProcessMqttPacket(ProcessMqttPacketContext context); | |||
} | |||
} |
@@ -0,0 +1,9 @@ | |||
namespace MQTTnet.Diagnostics.PacketInspection | |||
{ | |||
public enum MqttPacketFlowDirection | |||
{ | |||
Inbound, | |||
Outbound | |||
} | |||
} |
@@ -0,0 +1,9 @@ | |||
namespace MQTTnet.Diagnostics.PacketInspection | |||
{ | |||
public sealed class ProcessMqttPacketContext | |||
{ | |||
public MqttPacketFlowDirection Direction { get; set; } | |||
public byte[] Buffer { get; set; } | |||
} | |||
} |
@@ -11,7 +11,7 @@ namespace MQTTnet.Formatter | |||
public sealed class MqttPacketFormatterAdapter | |||
{ | |||
IMqttPacketFormatter _formatter; | |||
public MqttPacketFormatterAdapter(MqttProtocolVersion protocolVersion) | |||
: this(protocolVersion, new MqttPacketWriter()) | |||
{ | |||
@@ -26,7 +26,7 @@ namespace MQTTnet.Formatter | |||
public MqttPacketFormatterAdapter(IMqttPacketWriter writer) | |||
{ | |||
Writer = writer; | |||
} | |||
} | |||
public MqttProtocolVersion ProtocolVersion { get; private set; } = MqttProtocolVersion.Unknown; | |||
@@ -39,7 +39,7 @@ namespace MQTTnet.Formatter | |||
return _formatter.DataConverter; | |||
} | |||
} | |||
public IMqttPacketWriter Writer { get; } | |||
public ArraySegment<byte> Encode(MqttBasePacket packet) | |||
@@ -69,10 +69,10 @@ namespace MQTTnet.Formatter | |||
{ | |||
var protocolVersion = ParseProtocolVersion(receivedMqttPacket); | |||
// Reset the position of the stream because the protocol version is part of | |||
// Reset the position of the stream because the protocol version is part of | |||
// the regular CONNECT packet. So it will not properly deserialized if this | |||
// data is missing. | |||
receivedMqttPacket.Body.Seek(0); | |||
receivedMqttPacket.BodyReader.Seek(0); | |||
UseProtocolVersion(protocolVersion); | |||
} | |||
@@ -83,7 +83,7 @@ namespace MQTTnet.Formatter | |||
{ | |||
throw new InvalidOperationException("MQTT protocol version is invalid."); | |||
} | |||
switch (protocolVersion) | |||
{ | |||
case MqttProtocolVersion.V500: | |||
@@ -120,7 +120,7 @@ namespace MQTTnet.Formatter | |||
{ | |||
if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket)); | |||
if (receivedMqttPacket.Body.Length < 7) | |||
if (receivedMqttPacket.BodyReader.Length < 7) | |||
{ | |||
// 2 byte protocol name length | |||
// at least 4 byte protocol name | |||
@@ -128,8 +128,8 @@ namespace MQTTnet.Formatter | |||
throw new MqttProtocolViolationException("CONNECT packet must have at least 7 bytes."); | |||
} | |||
var protocolName = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); | |||
var protocolLevel = receivedMqttPacket.Body.ReadByte(); | |||
var protocolName = receivedMqttPacket.BodyReader.ReadStringWithLengthPrefix(); | |||
var protocolLevel = receivedMqttPacket.BodyReader.ReadByte(); | |||
if (protocolName == "MQTT") | |||
{ | |||
@@ -1,114 +0,0 @@ | |||
using System; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Channel; | |||
using MQTTnet.Exceptions; | |||
namespace MQTTnet.Formatter | |||
{ | |||
public sealed class MqttPacketReader | |||
{ | |||
readonly byte[] _singleByteBuffer = new byte[1]; | |||
readonly IMqttChannel _channel; | |||
public MqttPacketReader(IMqttChannel channel) | |||
{ | |||
_channel = channel ?? throw new ArgumentNullException(nameof(channel)); | |||
} | |||
public async Task<ReadFixedHeaderResult> ReadFixedHeaderAsync(byte[] fixedHeaderBuffer, CancellationToken cancellationToken) | |||
{ | |||
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | |||
// So in all cases at least 2 bytes must be read for a complete MQTT packet. | |||
var buffer = fixedHeaderBuffer; | |||
var totalBytesRead = 0; | |||
while (totalBytesRead < buffer.Length) | |||
{ | |||
var bytesRead = await _channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
if (bytesRead == 0) | |||
{ | |||
return new ReadFixedHeaderResult | |||
{ | |||
ConnectionClosed = true | |||
}; | |||
} | |||
totalBytesRead += bytesRead; | |||
} | |||
var hasRemainingLength = buffer[1] != 0; | |||
if (!hasRemainingLength) | |||
{ | |||
return new ReadFixedHeaderResult | |||
{ | |||
FixedHeader = new MqttFixedHeader(buffer[0], 0, totalBytesRead) | |||
}; | |||
} | |||
var bodyLength = await ReadBodyLengthAsync(buffer[1], cancellationToken).ConfigureAwait(false); | |||
if (!bodyLength.HasValue) | |||
{ | |||
return new ReadFixedHeaderResult | |||
{ | |||
ConnectionClosed = true | |||
}; | |||
} | |||
totalBytesRead += bodyLength.Value; | |||
return new ReadFixedHeaderResult | |||
{ | |||
FixedHeader = new MqttFixedHeader(buffer[0], bodyLength.Value, totalBytesRead) | |||
}; | |||
} | |||
async Task<int?> ReadBodyLengthAsync(byte initialEncodedByte, CancellationToken cancellationToken) | |||
{ | |||
var offset = 0; | |||
var multiplier = 128; | |||
var value = initialEncodedByte & 127; | |||
int encodedByte = initialEncodedByte; | |||
while ((encodedByte & 128) != 0) | |||
{ | |||
offset++; | |||
if (offset > 3) | |||
{ | |||
throw new MqttProtocolViolationException("Remaining length is invalid."); | |||
} | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
var readCount = await _channel.ReadAsync(_singleByteBuffer, 0, 1, cancellationToken).ConfigureAwait(false); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return null; | |||
} | |||
if (readCount == 0) | |||
{ | |||
return null; | |||
} | |||
encodedByte = _singleByteBuffer[0]; | |||
value += (encodedByte & 127) * multiplier; | |||
multiplier *= 128; | |||
} | |||
return value; | |||
} | |||
} | |||
} |
@@ -63,20 +63,20 @@ namespace MQTTnet.Formatter.V3 | |||
switch ((MqttControlPacketType)controlPacketType) | |||
{ | |||
case MqttControlPacketType.Connect: return DecodeConnectPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.ConnAck: return DecodeConnAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Connect: return DecodeConnectPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.ConnAck: return DecodeConnAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.Disconnect: return DisconnectPacket; | |||
case MqttControlPacketType.Publish: return DecodePublishPacket(receivedMqttPacket); | |||
case MqttControlPacketType.PubAck: return DecodePubAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubRec: return DecodePubRecPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubRel: return DecodePubRelPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubComp: return DecodePubCompPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubAck: return DecodePubAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubRec: return DecodePubRecPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubRel: return DecodePubRelPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubComp: return DecodePubCompPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PingReq: return PingReqPacket; | |||
case MqttControlPacketType.PingResp: return PingRespPacket; | |||
case MqttControlPacketType.Subscribe: return DecodeSubscribePacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.SubAck: return DecodeSubAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Unsubscibe: return DecodeUnsubscribePacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.UnsubAck: return DecodeUnsubAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Subscribe: return DecodeSubscribePacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.SubAck: return DecodeSubAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.Unsubscibe: return DecodeUnsubscribePacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.UnsubAck: return DecodeUnsubAckPacket(receivedMqttPacket.BodyReader); | |||
default: throw new MqttProtocolViolationException($"Packet type ({controlPacketType}) not supported."); | |||
} | |||
@@ -202,18 +202,18 @@ namespace MQTTnet.Formatter.V3 | |||
static MqttBasePacket DecodePublishPacket(ReceivedMqttPacket receivedMqttPacket) | |||
{ | |||
ThrowIfBodyIsEmpty(receivedMqttPacket.Body); | |||
ThrowIfBodyIsEmpty(receivedMqttPacket.BodyReader); | |||
var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; | |||
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); | |||
var dup = (receivedMqttPacket.FixedHeader & 0x8) > 0; | |||
var topic = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); | |||
var topic = receivedMqttPacket.BodyReader.ReadStringWithLengthPrefix(); | |||
ushort packetIdentifier = 0; | |||
if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) | |||
{ | |||
packetIdentifier = receivedMqttPacket.Body.ReadTwoByteInteger(); | |||
packetIdentifier = receivedMqttPacket.BodyReader.ReadTwoByteInteger(); | |||
} | |||
var packet = new MqttPublishPacket | |||
@@ -225,9 +225,9 @@ namespace MQTTnet.Formatter.V3 | |||
Dup = dup | |||
}; | |||
if (!receivedMqttPacket.Body.EndOfStream) | |||
if (!receivedMqttPacket.BodyReader.EndOfStream) | |||
{ | |||
packet.Payload = receivedMqttPacket.Body.ReadRemainingData(); | |||
packet.Payload = receivedMqttPacket.BodyReader.ReadRemainingData(); | |||
} | |||
return packet; | |||
@@ -25,21 +25,21 @@ namespace MQTTnet.Formatter.V5 | |||
switch ((MqttControlPacketType)controlPacketType) | |||
{ | |||
case MqttControlPacketType.Connect: return DecodeConnectPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.ConnAck: return DecodeConnAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Disconnect: return DecodeDisconnectPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Publish: return DecodePublishPacket(receivedMqttPacket.FixedHeader, receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubAck: return DecodePubAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubRec: return DecodePubRecPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubRel: return DecodePubRelPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.PubComp: return DecodePubCompPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Connect: return DecodeConnectPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.ConnAck: return DecodeConnAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.Disconnect: return DecodeDisconnectPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.Publish: return DecodePublishPacket(receivedMqttPacket.FixedHeader, receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubAck: return DecodePubAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubRec: return DecodePubRecPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubRel: return DecodePubRelPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PubComp: return DecodePubCompPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.PingReq: return DecodePingReqPacket(); | |||
case MqttControlPacketType.PingResp: return DecodePingRespPacket(); | |||
case MqttControlPacketType.Subscribe: return DecodeSubscribePacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.SubAck: return DecodeSubAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Unsubscibe: return DecodeUnsubscribePacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.UnsubAck: return DecodeUnsubAckPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Auth: return DecodeAuthPacket(receivedMqttPacket.Body); | |||
case MqttControlPacketType.Subscribe: return DecodeSubscribePacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.SubAck: return DecodeSubAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.Unsubscibe: return DecodeUnsubscribePacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.UnsubAck: return DecodeUnsubAckPacket(receivedMqttPacket.BodyReader); | |||
case MqttControlPacketType.Auth: return DecodeAuthPacket(receivedMqttPacket.BodyReader); | |||
default: throw new MqttProtocolViolationException($"Packet type ({controlPacketType}) not supported."); | |||
} | |||
@@ -3,6 +3,7 @@ using MQTTnet.Client.Options; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Formatter; | |||
using System; | |||
using MQTTnet.Channel; | |||
namespace MQTTnet.Implementations | |||
{ | |||
@@ -19,16 +20,19 @@ namespace MQTTnet.Implementations | |||
{ | |||
if (options == null) throw new ArgumentNullException(nameof(options)); | |||
IMqttChannel channel; | |||
switch (options.ChannelOptions) | |||
{ | |||
case MqttClientTcpOptions _: | |||
{ | |||
return new MqttChannelAdapter(new MqttTcpChannel(options), new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttPacketWriter()), _logger); | |||
channel = new MqttTcpChannel(options); | |||
break; | |||
} | |||
case MqttClientWebSocketOptions webSocketOptions: | |||
{ | |||
return new MqttChannelAdapter(new MqttWebSocketChannel(webSocketOptions), new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttPacketWriter()), _logger); | |||
channel = new MqttWebSocketChannel(webSocketOptions); | |||
break; | |||
} | |||
default: | |||
@@ -36,6 +40,9 @@ namespace MQTTnet.Implementations | |||
throw new NotSupportedException(); | |||
} | |||
} | |||
var packetFormatterAdapter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttPacketWriter()); | |||
return new MqttChannelAdapter(channel, packetFormatterAdapter, options.PacketInspector, _logger); | |||
} | |||
} | |||
} |
@@ -77,7 +77,7 @@ namespace MQTTnet.Implementations | |||
cancellationToken.ThrowIfCancellationRequested(); | |||
var networkStream = socket.GetStream(); | |||
if (_tcpOptions.TlsOptions?.UseTls == true) | |||
{ | |||
var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); | |||
@@ -94,7 +94,7 @@ namespace MQTTnet.Implementations | |||
}; | |||
await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationToken).ConfigureAwait(false); | |||
#else | |||
#else | |||
await sslStream.AuthenticateAsClientAsync(_tcpOptions.Server, LoadCertificates(), _tcpOptions.TlsOptions.SslProtocol, !_tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); | |||
#endif | |||
} | |||
@@ -89,7 +89,7 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(args.Socket, clientCertificate, _options), new MqttPacketFormatterAdapter(new MqttPacketWriter()), _rootLogger)) | |||
using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(args.Socket, clientCertificate, _options), new MqttPacketFormatterAdapter(new MqttPacketWriter()), null, _rootLogger)) | |||
{ | |||
await clientHandler(clientAdapter).ConfigureAwait(false); | |||
} | |||
@@ -178,7 +178,11 @@ namespace MQTTnet.Implementations | |||
var clientHandler = ClientHandler; | |||
if (clientHandler != null) | |||
{ | |||
using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(stream, remoteEndPoint, clientCertificate), new MqttPacketFormatterAdapter(new MqttPacketWriter()), _rootLogger)) | |||
using (var clientAdapter = new MqttChannelAdapter( | |||
new MqttTcpChannel(stream, remoteEndPoint, clientCertificate), | |||
new MqttPacketFormatterAdapter(new MqttPacketWriter()), | |||
null, | |||
_rootLogger)) | |||
{ | |||
await clientHandler(clientAdapter).ConfigureAwait(false); | |||
} | |||
@@ -6,7 +6,7 @@ using MQTTnet.Channel; | |||
namespace MQTTnet.Internal | |||
{ | |||
public class TestMqttChannel : IMqttChannel | |||
public sealed class TestMqttChannel : IMqttChannel | |||
{ | |||
readonly MemoryStream _stream; | |||
@@ -15,6 +15,11 @@ namespace MQTTnet.Internal | |||
_stream = stream; | |||
} | |||
public TestMqttChannel(byte[] buffer) | |||
{ | |||
_stream = new MemoryStream(buffer); | |||
} | |||
public string Endpoint { get; } = "<Test channel>"; | |||
public bool IsSecureConnection { get; } = false; | |||
@@ -13,6 +13,9 @@ namespace MQTTnet | |||
{ | |||
MqttQualityOfServiceLevel _qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce; | |||
string _topic; | |||
bool? _noLocal; | |||
bool? _retainAsPublished; | |||
MqttRetainHandling? _retainHandling; | |||
public MqttTopicFilterBuilder WithTopic(string topic) | |||
{ | |||
@@ -44,6 +47,24 @@ namespace MQTTnet | |||
return this; | |||
} | |||
public MqttTopicFilterBuilder WithNoLocal(bool? value = true) | |||
{ | |||
_noLocal = value; | |||
return this; | |||
} | |||
public MqttTopicFilterBuilder WithRetainAsPublished(bool? value = true) | |||
{ | |||
_retainAsPublished = value; | |||
return this; | |||
} | |||
public MqttTopicFilterBuilder WithRetainHandling(MqttRetainHandling? value) | |||
{ | |||
_retainHandling = value; | |||
return this; | |||
} | |||
public MqttTopicFilter Build() | |||
{ | |||
if (string.IsNullOrEmpty(_topic)) | |||
@@ -51,7 +72,14 @@ namespace MQTTnet | |||
throw new MqttProtocolViolationException("Topic is not set."); | |||
} | |||
return new MqttTopicFilter { Topic = _topic, QualityOfServiceLevel = _qualityOfServiceLevel }; | |||
return new MqttTopicFilter | |||
{ | |||
Topic = _topic, | |||
QualityOfServiceLevel = _qualityOfServiceLevel, | |||
NoLocal = _noLocal, | |||
RetainAsPublished = _retainAsPublished, | |||
RetainHandling = _retainHandling | |||
}; | |||
} | |||
} | |||
} |
@@ -43,7 +43,7 @@ namespace MQTTnet.Benchmarks | |||
var channel = new TestMqttChannel(_stream); | |||
_channelAdapter = new MqttChannelAdapter(channel, serializer, new MqttNetLogger()); | |||
_channelAdapter = new MqttChannelAdapter(channel, serializer, null, new MqttNetLogger()); | |||
} | |||
[Benchmark] | |||
@@ -9,6 +9,7 @@ using MQTTnet.Channel; | |||
using MQTTnet.Formatter; | |||
using MQTTnet.Formatter.V3; | |||
using BenchmarkDotNet.Jobs; | |||
using MQTTnet.Diagnostics; | |||
namespace MQTTnet.Benchmarks | |||
{ | |||
@@ -48,19 +49,13 @@ namespace MQTTnet.Benchmarks | |||
{ | |||
var channel = new BenchmarkMqttChannel(_serializedPacket); | |||
var fixedHeader = new byte[2]; | |||
var reader = new MqttPacketReader(channel); | |||
var reader = new MqttChannelAdapter(channel, new MqttPacketFormatterAdapter(new MqttPacketWriter()), null, new MqttNetLogger()); | |||
for (var i = 0; i < 10000; i++) | |||
{ | |||
channel.Reset(); | |||
var header = reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
var receivedPacket = new ReceivedMqttPacket( | |||
header.Flags, | |||
new MqttPacketBodyReader(_serializedPacket.Array, _serializedPacket.Count - header.RemainingLength, _serializedPacket.Array.Length), 0); | |||
_serializer.Decode(receivedPacket); | |||
var header = reader.ReceivePacketAsync(CancellationToken.None).GetAwaiter().GetResult(); | |||
} | |||
} | |||
@@ -1,23 +1,24 @@ | |||
using System.IO; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Formatter; | |||
using MQTTnet.Internal; | |||
//using System.IO; | |||
//using System.Threading; | |||
//using System.Threading.Tasks; | |||
//using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
//using MQTTnet.Formatter; | |||
//using MQTTnet.Internal; | |||
namespace MQTTnet.Tests | |||
{ | |||
[TestClass] | |||
public class MqttPacketReader_Tests | |||
{ | |||
[TestMethod] | |||
public async Task MqttPacketReader_EmptyStream() | |||
{ | |||
var fixedHeader = new byte[2]; | |||
var reader = new MqttPacketReader(new TestMqttChannel(new MemoryStream())); | |||
var readResult = await reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None); | |||
//namespace MQTTnet.Tests | |||
//{ | |||
// [TestClass] | |||
// public class MqttPacketReader_Tests | |||
// { | |||
// [TestMethod] | |||
// public async Task MqttPacketReader_EmptyStream() | |||
// { | |||
// var fixedHeader = new byte[2]; | |||
// var reader = new MqttPacketReader(new TestMqttChannel(new MemoryStream())); | |||
// var readResult = await reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None); | |||
Assert.IsTrue(readResult.ConnectionClosed); | |||
} | |||
} | |||
} | |||
// Assert.IsTrue(readResult.ConnectionClosed); | |||
// } | |||
// } | |||
//} | |||
// TODO: Fix |
@@ -6,6 +6,7 @@ using System.Text; | |||
using System.Threading; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Adapter; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Formatter; | |||
using MQTTnet.Formatter.V3; | |||
@@ -22,7 +23,7 @@ namespace MQTTnet.Tests | |||
[TestMethod] | |||
public void DetectVersionFromMqttConnectPacket() | |||
{ | |||
var p = new MqttConnectPacket | |||
var packet = new MqttConnectPacket | |||
{ | |||
ClientId = "XYZ", | |||
Password = Encoding.UTF8.GetBytes("PASS"), | |||
@@ -30,17 +31,26 @@ namespace MQTTnet.Tests | |||
KeepAlivePeriod = 123, | |||
CleanSession = true | |||
}; | |||
var adapter = new MqttPacketFormatterAdapter(WriterFactory()); | |||
Assert.AreEqual(MqttProtocolVersion.V310, DeserializeAndDetectVersion(adapter, Serialize(p, MqttProtocolVersion.V310))); | |||
Assert.AreEqual(MqttProtocolVersion.V311, DeserializeAndDetectVersion(adapter, Serialize(p, MqttProtocolVersion.V311))); | |||
Assert.AreEqual(MqttProtocolVersion.V500, DeserializeAndDetectVersion(adapter, Serialize(p, MqttProtocolVersion.V500))); | |||
Assert.AreEqual( | |||
MqttProtocolVersion.V310, | |||
DeserializeAndDetectVersion(new MqttPacketFormatterAdapter(new MqttPacketWriter()), Serialize(packet, MqttProtocolVersion.V310))); | |||
Assert.AreEqual( | |||
MqttProtocolVersion.V311, | |||
DeserializeAndDetectVersion(new MqttPacketFormatterAdapter(new MqttPacketWriter()), Serialize(packet, MqttProtocolVersion.V311))); | |||
Assert.AreEqual( | |||
MqttProtocolVersion.V500, | |||
DeserializeAndDetectVersion(new MqttPacketFormatterAdapter(new MqttPacketWriter()), Serialize(packet, MqttProtocolVersion.V500))); | |||
var adapter = new MqttPacketFormatterAdapter(new MqttPacketWriter()); | |||
var ex = Assert.ThrowsException<MqttProtocolViolationException>(() => DeserializeAndDetectVersion(adapter, WriterFactory().AddMqttHeader(MqttControlPacketType.Connect, new byte[0]))); | |||
Assert.AreEqual("CONNECT packet must have at least 7 bytes.", ex.Message); | |||
ex = Assert.ThrowsException<MqttProtocolViolationException>(() => DeserializeAndDetectVersion(adapter, WriterFactory().AddMqttHeader(MqttControlPacketType.Connect, new byte[7]))); | |||
Assert.AreEqual("Protocol '' not supported.", ex.Message); | |||
ex = Assert.ThrowsException<MqttProtocolViolationException>(() => DeserializeAndDetectVersion(adapter, WriterFactory().AddMqttHeader(MqttControlPacketType.Connect, new byte[] { 255, 255, 0,0,0,0,0 }))); | |||
ex = Assert.ThrowsException<MqttProtocolViolationException>(() => DeserializeAndDetectVersion(adapter, WriterFactory().AddMqttHeader(MqttControlPacketType.Connect, new byte[] { 255, 255, 0, 0, 0, 0, 0 }))); | |||
Assert.AreEqual("Expected at least 65537 bytes but there are only 7 bytes", ex.Message); | |||
} | |||
@@ -205,24 +215,28 @@ namespace MQTTnet.Tests | |||
Payload = payload | |||
}; | |||
var buffer = serializer.Encode(publishPacket); | |||
var testChannel = new TestMqttChannel(new MemoryStream(buffer.Array, buffer.Offset, buffer.Count)); | |||
var header = new MqttPacketReader(testChannel).ReadFixedHeaderAsync( | |||
new byte[2], | |||
CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
var publishPacketCopy = Roundtrip(publishPacket); | |||
//var buffer = serializer.Encode(publishPacket); | |||
//var testChannel = new TestMqttChannel(new MemoryStream(buffer.Array, buffer.Offset, buffer.Count)); | |||
//var header = new MqttPacketReader(testChannel).ReadFixedHeaderAsync( | |||
// new byte[2], | |||
// CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
var eof = buffer.Offset + buffer.Count; | |||
//var eof = buffer.Offset + buffer.Count; | |||
var receivedPacket = new ReceivedMqttPacket( | |||
header.Flags, | |||
new MqttPacketBodyReader(buffer.Array, eof - header.RemainingLength, buffer.Count + buffer.Offset), | |||
0); | |||
//var receivedPacket = new ReceivedMqttPacket( | |||
// header.Flags, | |||
// new MqttPacketBodyReader(buffer.Array, eof - header.RemainingLength, buffer.Count + buffer.Offset), | |||
// 0); | |||
var packet = (MqttPublishPacket)serializer.Decode(receivedPacket); | |||
//var packet = (MqttPublishPacket)serializer.Decode(receivedPacket); | |||
Assert.AreEqual(publishPacket.Topic, packet.Topic); | |||
Assert.IsTrue(publishPacket.Payload.SequenceEqual(packet.Payload)); | |||
Assert.AreEqual(publishPacket.Topic, publishPacketCopy.Topic); | |||
Assert.IsTrue(publishPacket.Payload.SequenceEqual(publishPacketCopy.Payload)); | |||
} | |||
[TestMethod] | |||
@@ -262,7 +276,7 @@ namespace MQTTnet.Tests | |||
[TestMethod] | |||
public void SerializeV500_MqttPublishPacket() | |||
{ | |||
var prop = new MqttPublishPacketProperties {UserProperties = new List<MqttUserProperty>()}; | |||
var prop = new MqttPublishPacketProperties { UserProperties = new List<MqttUserProperty>() }; | |||
prop.ResponseTopic = "/Response"; | |||
@@ -581,15 +595,14 @@ namespace MQTTnet.Tests | |||
DeserializeAndCompare(p, "sAIAew=="); | |||
} | |||
private void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
{ | |||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Serialize(packet, protocolVersion))); | |||
} | |||
private byte[] Serialize(MqttBasePacket packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
byte[] Serialize(MqttBasePacket packet, MqttProtocolVersion protocolVersion) | |||
{ | |||
var serializer = MqttPacketFormatterAdapter.GetMqttPacketFormatter(protocolVersion, WriterFactory()); | |||
return Join(serializer.Encode(packet)); | |||
return MqttPacketFormatterAdapter.GetMqttPacketFormatter(protocolVersion, WriterFactory()).Encode(packet).ToArray(); | |||
} | |||
protected virtual IMqttPacketWriter WriterFactory() | |||
@@ -602,83 +615,92 @@ namespace MQTTnet.Tests | |||
return new MqttPacketBodyReader(data, 0, data.Length); | |||
} | |||
private void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
{ | |||
var writer = WriterFactory(); | |||
var serializer = MqttPacketFormatterAdapter.GetMqttPacketFormatter(protocolVersion, writer); | |||
var buffer1 = serializer.Encode(packet); | |||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||
using (var headerStream = new MemoryStream(buffer1.ToArray())) | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var fixedHeader = new byte[2]; | |||
var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
var adapter = new MqttChannelAdapter(channel, new MqttPacketFormatterAdapter(protocolVersion, writer), null, new MqttNetLogger()); | |||
var receivedPacket = adapter.ReceivePacketAsync(CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | |||
{ | |||
var reader = ReaderFactory(bodyStream.ToArray()); | |||
var deserializedPacket = serializer.Decode(new ReceivedMqttPacket(header.Flags, reader, 0)); | |||
var buffer2 = serializer.Encode(deserializedPacket); | |||
var buffer2 = serializer.Encode(receivedPacket); | |||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2))); | |||
} | |||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(buffer2.ToArray())); | |||
//adapter.ReceivePacketAsync(CancellationToken.None); | |||
//var fixedHeader = new byte[2]; | |||
//var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
//using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | |||
//{ | |||
// var reader = ReaderFactory(bodyStream.ToArray()); | |||
// var deserializedPacket = serializer.Decode(new ReceivedMqttPacket(header.Flags, reader, 0)); | |||
// var buffer2 = serializer.Encode(deserializedPacket); | |||
// Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2))); | |||
//} | |||
} | |||
} | |||
private T Roundtrip<T>(T packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
where T : MqttBasePacket | |||
TPacket Roundtrip<TPacket>(TPacket packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
where TPacket : MqttBasePacket | |||
{ | |||
var writer = WriterFactory(); | |||
var serializer = MqttPacketFormatterAdapter.GetMqttPacketFormatter(protocolVersion, writer); | |||
var buffer = serializer.Encode(packet); | |||
var buffer1 = serializer.Encode(packet); | |||
var channel = new TestMqttChannel(buffer.ToArray()); | |||
var adapter = new MqttChannelAdapter(channel, new MqttPacketFormatterAdapter(protocolVersion, writer), null, new MqttNetLogger()); | |||
return (TPacket)adapter.ReceivePacketAsync(CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var fixedHeader = new byte[2]; | |||
//using (var headerStream = new MemoryStream(buffer1.ToArray())) | |||
//{ | |||
var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, (int)header.RemainingLength)) | |||
{ | |||
var reader = ReaderFactory(bodyStream.ToArray()); | |||
return (T)serializer.Decode(new ReceivedMqttPacket(header.Flags, reader, 0)); | |||
} | |||
} | |||
} | |||
private MqttProtocolVersion DeserializeAndDetectVersion(MqttPacketFormatterAdapter adapter, byte[] buffer) | |||
{ | |||
using (var headerStream = new MemoryStream(buffer)) | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var fixedHeader = new byte[2]; | |||
var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
// //var fixedHeader = new byte[2]; | |||
using (var bodyStream = new MemoryStream(buffer, (int)headerStream.Position, (int)header.RemainingLength)) | |||
{ | |||
var reader = ReaderFactory(bodyStream.ToArray()); | |||
var packet = new ReceivedMqttPacket(header.Flags, reader, 0); | |||
adapter.DetectProtocolVersion(packet); | |||
return adapter.ProtocolVersion; | |||
} | |||
} | |||
// //var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
// //using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, (int)header.RemainingLength)) | |||
// //{ | |||
// // var reader = ReaderFactory(bodyStream.ToArray()); | |||
// // return (T)serializer.Decode(new ReceivedMqttPacket(header.Flags, reader, 0)); | |||
// //} | |||
//} | |||
} | |||
private static byte[] Join(params ArraySegment<byte>[] chunks) | |||
MqttProtocolVersion DeserializeAndDetectVersion(MqttPacketFormatterAdapter packetFormatterAdapter, byte[] buffer) | |||
{ | |||
var buffer = new MemoryStream(); | |||
foreach (var chunk in chunks) | |||
{ | |||
buffer.Write(chunk.Array, chunk.Offset, chunk.Count); | |||
} | |||
var channel = new TestMqttChannel(buffer); | |||
var adapter = new MqttChannelAdapter(channel, packetFormatterAdapter, null, new MqttNetLogger()); | |||
adapter.ReceivePacketAsync(CancellationToken.None).GetAwaiter().GetResult(); | |||
return packetFormatterAdapter.ProtocolVersion; | |||
//using (var headerStream = new MemoryStream(buffer)) | |||
//{ | |||
// //var fixedHeader = new byte[2]; | |||
// //var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; | |||
return buffer.ToArray(); | |||
// //using (var bodyStream = new MemoryStream(buffer, (int)headerStream.Position, (int)header.RemainingLength)) | |||
// //{ | |||
// // var reader = ReaderFactory(bodyStream.ToArray()); | |||
// // var packet = new ReceivedMqttPacket(header.Flags, reader, 0); | |||
// // packetFormatterAdapter.DetectProtocolVersion(packet); | |||
// // return adapter.ProtocolVersion; | |||
// //} | |||
//} | |||
} | |||
} | |||
} |