diff --git a/Source/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs b/Source/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs index a83fcef..e8f570c 100644 --- a/Source/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs +++ b/Source/MQTTnet.AspnetCore/Client/MqttClientConnectionContextFactory.cs @@ -21,7 +21,10 @@ namespace MQTTnet.AspNetCore.Client { var endpoint = new DnsEndPoint(tcpOptions.Server, tcpOptions.GetPort()); var tcpConnection = new TcpConnection(endpoint); - return new MqttConnectionContext(new MqttPacketFormatterAdapter(options.ProtocolVersion), tcpConnection); + + var writer = new SpanBasedMqttPacketWriter(); + var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, writer); + return new MqttConnectionContext(formatter, tcpConnection); } default: { diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index df67eeb..400aa07 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -123,14 +123,19 @@ namespace MQTTnet.AspNetCore public async Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout, CancellationToken cancellationToken) { - var buffer = PacketFormatterAdapter.Encode(packet); - var msg = buffer.AsMemory(); - var output = _output; + var formatter = PacketFormatterAdapter; + await _writerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); try { - await output.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); + var buffer = formatter.Encode(packet); + var msg = buffer.AsMemory(); + var output = _output; + msg.CopyTo(output.GetMemory(msg.Length)); + PacketFormatterAdapter.FreeBuffer(); + output.Advance(msg.Length); + await output.FlushAsync().ConfigureAwait(false); } finally { diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs index ec2528f..f3cb91a 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs @@ -21,7 +21,9 @@ namespace MQTTnet.AspNetCore transferFormatFeature.ActiveFormat = TransferFormat.Binary; } - using (var adapter = new MqttConnectionContext(new MqttPacketFormatterAdapter(), connection)) + var writer = new SpanBasedMqttPacketWriter(); + var formatter = new MqttPacketFormatterAdapter(writer); + using (var adapter = new MqttConnectionContext(formatter, connection)) { var args = new MqttServerAdapterClientAcceptedEventArgs(adapter); ClientAcceptedHandler?.Invoke(args); diff --git a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs b/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs index 3e5e607..1cd512b 100644 --- a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs +++ b/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs @@ -43,8 +43,10 @@ namespace MQTTnet.AspNetCore var isSecureConnection = clientCertificate != null; clientCertificate?.Dispose(); + var writer = new SpanBasedMqttPacketWriter(); + var formatter = new MqttPacketFormatterAdapter(writer); var channel = new MqttWebSocketChannel(webSocket, endpoint, isSecureConnection); - var channelAdapter = new MqttChannelAdapter(channel, new MqttPacketFormatterAdapter(), _logger.CreateChildLogger(nameof(MqttWebSocketServerAdapter))); + var channelAdapter = new MqttChannelAdapter(channel, formatter, new MqttNetLogger().CreateChildLogger(nameof(MqttWebSocketServerAdapter))); var eventArgs = new MqttServerAdapterClientAcceptedEventArgs(channelAdapter); ClientAcceptedHandler?.Invoke(eventArgs); diff --git a/Source/MQTTnet.AspnetCore/SpanBasedMqttPacketWriter.cs b/Source/MQTTnet.AspnetCore/SpanBasedMqttPacketWriter.cs new file mode 100644 index 0000000..7ad6e00 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/SpanBasedMqttPacketWriter.cs @@ -0,0 +1,147 @@ +using MQTTnet.Formatter; +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Text; + +namespace MQTTnet.AspNetCore +{ + public class SpanBasedMqttPacketWriter : IMqttPacketWriter + { + private readonly ArrayPool _pool; + + public SpanBasedMqttPacketWriter() + { + _pool = ArrayPool.Create(); + + } + + private byte[] _buffer; + private int _position; + + public int Length { get; set; } + + public void FreeBuffer() + { + _pool.Return(_buffer); + } + + public byte[] GetBuffer() + { + return _buffer; + } + + public void Reset(int v) + { + _buffer = _pool.Rent(1500); + Length = v; + _position = v; + } + + public void Seek(int v) + { + _position = v; + } + + public void Write(byte value) + { + GrowIfNeeded(1); + _buffer[_position] = value; + Commit(1); + } + + public void Write(ushort value) + { + GrowIfNeeded(2); + + BinaryPrimitives.WriteUInt16BigEndian(_buffer.AsSpan(_position), value); + Commit(2); + } + + public void Write(IMqttPacketWriter propertyWriter) + { + if (propertyWriter is SpanBasedMqttPacketWriter writer) + { + GrowIfNeeded(1); + } + + throw new InvalidOperationException($"{nameof(propertyWriter)} must be of type {typeof(SpanBasedMqttPacketWriter).Name}"); + } + + public void Write(byte[] payload, int start, int length) + { + GrowIfNeeded(length); + + payload.AsSpan(start, length).CopyTo(_buffer.AsSpan(_position)); + Commit(length); + } + + public void WriteVariableLengthInteger(uint value) + { + GrowIfNeeded(4); + + var x = value; + do + { + var encodedByte = x % 128; + x = x / 128; + if (x > 0) + { + encodedByte = encodedByte | 128; + } + + _buffer[_position] = (byte)encodedByte; + Commit(1); + } while (x > 0); + } + + public void WriteWithLengthPrefix(string value) + { + var bytesLength = Encoding.UTF8.GetByteCount(value ?? string.Empty); + GrowIfNeeded(bytesLength + 2); + + Write((ushort)bytesLength); + Encoding.UTF8.GetBytes(value ?? string.Empty, 0, value?.Length ?? 0, _buffer, _position); + Commit(bytesLength); + } + + public void WriteWithLengthPrefix(byte[] payload) + { + GrowIfNeeded(payload.Length + 2); + + Write((ushort)payload.Length); + payload.CopyTo(_buffer, _position); + Commit(payload.Length); + } + + private void Commit(int count) + { + if (_position == Length) + { + Length += count; + } + + _position += count; + } + + private void GrowIfNeeded(int requiredAdditional) + { + var requiredTotal = _position + requiredAdditional; + if (_buffer.Length >= requiredTotal) + { + return; + } + + var newBufferLength = _buffer.Length; + while (newBufferLength < requiredTotal) + { + newBufferLength *= 2; + } + + var newBuffer = _pool.Rent(newBufferLength); + Array.Copy(_buffer, newBuffer, _buffer.Length); + _pool.Return(_buffer); + _buffer = newBuffer; + } + } +} diff --git a/Source/MQTTnet/Formatter/IMqttPacketWriter.cs b/Source/MQTTnet/Formatter/IMqttPacketWriter.cs new file mode 100644 index 0000000..cce98cf --- /dev/null +++ b/Source/MQTTnet/Formatter/IMqttPacketWriter.cs @@ -0,0 +1,20 @@ +namespace MQTTnet.Formatter +{ + public interface IMqttPacketWriter + { + int Length { get; } + + void WriteWithLengthPrefix(string value); + void Write(byte returnCode); + void WriteWithLengthPrefix(byte[] payload); + void Write(ushort keepAlivePeriod); + + void Write(IMqttPacketWriter propertyWriter); + void WriteVariableLengthInteger(uint length); + void Write(byte[] payload, int v, int length); + void Reset(int v); + void Seek(int v); + void FreeBuffer(); + byte[] GetBuffer(); + } +} diff --git a/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs b/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs index 47f80db..0b44228 100644 --- a/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs +++ b/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs @@ -12,14 +12,26 @@ namespace MQTTnet.Formatter private IMqttPacketFormatter _formatter; public MqttPacketFormatterAdapter() + : this(new MqttPacketWriter()) { } - + public MqttPacketFormatterAdapter(MqttProtocolVersion protocolVersion) + : this(protocolVersion, new MqttPacketWriter()) + { + } + + public MqttPacketFormatterAdapter(MqttProtocolVersion protocolVersion, IMqttPacketWriter writer) + : this(writer) { UseProtocolVersion(protocolVersion); } + public MqttPacketFormatterAdapter(IMqttPacketWriter writer) + { + Writer = writer; + } + public MqttProtocolVersion ProtocolVersion { get; private set; } = MqttProtocolVersion.Unknown; public IMqttDataConverter DataConverter @@ -31,6 +43,8 @@ namespace MQTTnet.Formatter return _formatter.DataConverter; } } + + public IMqttPacketWriter Writer { get; } public ArraySegment Encode(MqttBasePacket packet) { @@ -80,20 +94,20 @@ namespace MQTTnet.Formatter { case MqttProtocolVersion.V500: { - _formatter = new MqttV500PacketFormatter(); + _formatter = new MqttV500PacketFormatter(Writer); break; } case MqttProtocolVersion.V311: { - _formatter = new MqttV311PacketFormatter(); + _formatter = new MqttV311PacketFormatter(Writer); break; } case MqttProtocolVersion.V310: { - _formatter = new MqttV310PacketFormatter(); + _formatter = new MqttV310PacketFormatter(Writer); break; } diff --git a/Source/MQTTnet/Formatter/MqttPacketWriter.cs b/Source/MQTTnet/Formatter/MqttPacketWriter.cs index 1cedd08..d11cc54 100644 --- a/Source/MQTTnet/Formatter/MqttPacketWriter.cs +++ b/Source/MQTTnet/Formatter/MqttPacketWriter.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Formatter /// same as for the original MemoryStream in .net. Also this implementation allows accessing the internal /// buffer for all platforms and .net framework versions (which is not available at the regular MemoryStream). /// - public class MqttPacketWriter + public class MqttPacketWriter : IMqttPacketWriter { private static readonly ArraySegment ZeroVariableLengthIntegerArray = new ArraySegment(new byte[1], 0, 1); private static readonly ArraySegment ZeroTwoByteIntegerArray = new ArraySegment(new byte[2], 0, 2); @@ -33,6 +33,19 @@ namespace MQTTnet.Formatter return (byte)fixedHeader; } + public static int GetLengthOfVariableInteger(uint value) + { + var result = 0; + var x = value; + do + { + x = x / 128; + result++; + } while (x > 0); + + return result; + } + public static ArraySegment EncodeVariableLengthInteger(uint value) { if (value == 0) @@ -129,16 +142,21 @@ namespace MQTTnet.Formatter IncreasePosition(count); } - public void Write(MqttPacketWriter propertyWriter) + public void Write(IMqttPacketWriter propertyWriter) { if (propertyWriter == null) throw new ArgumentNullException(nameof(propertyWriter)); - if (propertyWriter.Length == 0) + if (propertyWriter is MqttPacketWriter writer) { - return; + if (writer.Length == 0) + { + return; + } + + Write(writer._buffer, 0, writer.Length); } - Write(propertyWriter._buffer, 0, propertyWriter.Length); + throw new InvalidOperationException($"{nameof(propertyWriter)} must be of type {typeof(MqttPacketWriter).Name}"); } public void Reset(int length) diff --git a/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs b/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs index d3e6b41..70af10e 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs @@ -10,8 +10,20 @@ namespace MQTTnet.Formatter.V3 public class MqttV310PacketFormatter : IMqttPacketFormatter { private const int FixedHeaderSize = 1; + + private readonly IMqttPacketWriter _packetWriter; + + public MqttV310PacketFormatter() + : this(new MqttPacketWriter()) + { + + } + + public MqttV310PacketFormatter(IMqttPacketWriter packetWriter) + { + _packetWriter = packetWriter; + } - private readonly MqttPacketWriter _packetWriter = new MqttPacketWriter(); public IMqttDataConverter DataConverter { get; } = new MqttV310DataConverter(); @@ -26,17 +38,18 @@ namespace MQTTnet.Formatter.V3 var fixedHeader = EncodePacket(packet, _packetWriter); var remainingLength = (uint)(_packetWriter.Length - 5); - var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength); + var remainingLengthSize = MqttPacketWriter.GetLengthOfVariableInteger(remainingLength); - var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; + var headerSize = FixedHeaderSize + remainingLengthSize; var headerOffset = 5 - headerSize; // Position cursor on correct offset on beginning of array (has leading 0x0) _packetWriter.Seek(headerOffset); _packetWriter.Write(fixedHeader); - _packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); + _packetWriter.WriteVariableLengthInteger(remainingLength); var buffer = _packetWriter.GetBuffer(); + return new ArraySegment(buffer, headerOffset, _packetWriter.Length - headerOffset); } @@ -76,7 +89,7 @@ namespace MQTTnet.Formatter.V3 _packetWriter.FreeBuffer(); } - private byte EncodePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) + private byte EncodePacket(MqttBasePacket packet, IMqttPacketWriter packetWriter) { switch (packet) { @@ -331,7 +344,7 @@ namespace MQTTnet.Formatter.V3 } } - protected virtual byte EncodeConnectPacket(MqttConnectPacket packet, MqttPacketWriter packetWriter) + protected virtual byte EncodeConnectPacket(MqttConnectPacket packet, IMqttPacketWriter packetWriter) { ValidateConnectPacket(packet); @@ -393,7 +406,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - protected virtual byte EncodeConnAckPacket(MqttConnAckPacket packet, MqttPacketWriter packetWriter) + protected virtual byte EncodeConnAckPacket(MqttConnAckPacket packet, IMqttPacketWriter packetWriter) { packetWriter.Write(0); // Reserved. packetWriter.Write((byte)packet.ReturnCode.Value); @@ -401,7 +414,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static byte EncodePubRelPacket(MqttPubRelPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubRelPacket(MqttPubRelPacket packet, IMqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { @@ -413,7 +426,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static byte EncodePublishPacket(MqttPublishPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePublishPacket(MqttPublishPacket packet, IMqttPacketWriter packetWriter) { ValidatePublishPacket(packet); @@ -458,7 +471,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } - private static byte EncodePubAckPacket(MqttPubAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubAckPacket(MqttPubAckPacket packet, IMqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { @@ -470,7 +483,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static byte EncodePubRecPacket(MqttPubRecPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubRecPacket(MqttPubRecPacket packet, IMqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { @@ -482,7 +495,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static byte EncodePubCompPacket(MqttPubCompPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubCompPacket(MqttPubCompPacket packet, IMqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { @@ -494,7 +507,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static byte EncodeSubscribePacket(MqttSubscribePacket packet, MqttPacketWriter packetWriter) + private static byte EncodeSubscribePacket(MqttSubscribePacket packet, IMqttPacketWriter packetWriter) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); @@ -517,7 +530,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static byte EncodeSubAckPacket(MqttSubAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeSubAckPacket(MqttSubAckPacket packet, IMqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { @@ -537,7 +550,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, MqttPacketWriter packetWriter) + private static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, IMqttPacketWriter packetWriter) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); @@ -559,7 +572,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, IMqttPacketWriter packetWriter) { if (!packet.PacketIdentifier.HasValue) { diff --git a/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs b/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs index 44725d4..a846cd0 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs @@ -6,7 +6,17 @@ namespace MQTTnet.Formatter.V3 { public class MqttV311PacketFormatter : MqttV310PacketFormatter { - protected override byte EncodeConnectPacket(MqttConnectPacket packet, MqttPacketWriter packetWriter) + public MqttV311PacketFormatter() + : base() + { + } + + public MqttV311PacketFormatter(IMqttPacketWriter packetWriter) + : base(packetWriter) + { + } + + protected override byte EncodeConnectPacket(MqttConnectPacket packet, IMqttPacketWriter packetWriter) { ValidateConnectPacket(packet); @@ -68,7 +78,7 @@ namespace MQTTnet.Formatter.V3 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - protected override byte EncodeConnAckPacket(MqttConnAckPacket packet, MqttPacketWriter packetWriter) + protected override byte EncodeConnAckPacket(MqttConnAckPacket packet, IMqttPacketWriter packetWriter) { byte connectAcknowledgeFlags = 0x0; if (packet.IsSessionPresent) diff --git a/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs b/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs index 83f003b..5ccd030 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs @@ -8,7 +8,19 @@ namespace MQTTnet.Formatter.V5 { public class MqttV500PacketEncoder { - private readonly MqttPacketWriter _packetWriter = new MqttPacketWriter(); + private readonly IMqttPacketWriter _packetWriter; + + public MqttV500PacketEncoder() + : this(new MqttPacketWriter()) + { + + } + + public MqttV500PacketEncoder(IMqttPacketWriter packetWriter) + { + _packetWriter = packetWriter; + } + public ArraySegment Encode(MqttBasePacket packet) { @@ -21,15 +33,15 @@ namespace MQTTnet.Formatter.V5 var fixedHeader = EncodePacket(packet, _packetWriter); var remainingLength = (uint)(_packetWriter.Length - 5); - var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength); - - var headerSize = 1 + remainingLengthBuffer.Count; + var remainingLengthSize = MqttPacketWriter.GetLengthOfVariableInteger(remainingLength); + + var headerSize = 1 + remainingLengthSize; var headerOffset = 5 - headerSize; // Position cursor on correct offset on beginning of array (has leading 0x0) _packetWriter.Seek(headerOffset); _packetWriter.Write(fixedHeader); - _packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); + _packetWriter.WriteVariableLengthInteger(remainingLength); var buffer = _packetWriter.GetBuffer(); return new ArraySegment(buffer, headerOffset, _packetWriter.Length - headerOffset); @@ -40,7 +52,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.FreeBuffer(); } - private static byte EncodePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) + private static byte EncodePacket(MqttBasePacket packet, IMqttPacketWriter packetWriter) { switch (packet) { @@ -64,7 +76,7 @@ namespace MQTTnet.Formatter.V5 } } - private static byte EncodeConnectPacket(MqttConnectPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeConnectPacket(MqttConnectPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -150,7 +162,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - private static byte EncodeConnAckPacket(MqttConnAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeConnAckPacket(MqttConnAckPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -195,7 +207,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static byte EncodePublishPacket(MqttPublishPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePublishPacket(MqttPublishPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -261,7 +273,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } - private static byte EncodePubAckPacket(MqttPubAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubAckPacket(MqttPubAckPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -294,7 +306,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static byte EncodePubRecPacket(MqttPubRecPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubRecPacket(MqttPubRecPacket packet, IMqttPacketWriter packetWriter) { ThrowIfPacketIdentifierIsInvalid(packet); @@ -322,7 +334,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static byte EncodePubRelPacket(MqttPubRelPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubRelPacket(MqttPubRelPacket packet, IMqttPacketWriter packetWriter) { ThrowIfPacketIdentifierIsInvalid(packet); @@ -350,7 +362,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static byte EncodePubCompPacket(MqttPubCompPacket packet, MqttPacketWriter packetWriter) + private static byte EncodePubCompPacket(MqttPubCompPacket packet, IMqttPacketWriter packetWriter) { ThrowIfPacketIdentifierIsInvalid(packet); @@ -378,7 +390,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static byte EncodeSubscribePacket(MqttSubscribePacket packet, MqttPacketWriter packetWriter) + private static byte EncodeSubscribePacket(MqttSubscribePacket packet, IMqttPacketWriter packetWriter) { if (packet.TopicFilters?.Any() != true) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); @@ -425,7 +437,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static byte EncodeSubAckPacket(MqttSubAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeSubAckPacket(MqttSubAckPacket packet, IMqttPacketWriter packetWriter) { if (packet.ReasonCodes?.Any() != true) throw new MqttProtocolViolationException("At least one reason code must be set[MQTT - 3.8.3 - 3]."); @@ -450,7 +462,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, MqttPacketWriter packetWriter) + private static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, IMqttPacketWriter packetWriter) { if (packet.TopicFilters?.Any() != true) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); @@ -474,7 +486,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, IMqttPacketWriter packetWriter) { if (packet.ReasonCodes?.Any() != true) throw new MqttProtocolViolationException("At least one reason code must be set[MQTT - 3.8.3 - 3]."); @@ -499,7 +511,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } - private static byte EncodeDisconnectPacket(MqttDisconnectPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeDisconnectPacket(MqttDisconnectPacket packet, IMqttPacketWriter packetWriter) { if (!packet.ReasonCode.HasValue) { @@ -532,7 +544,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PingResp); } - private static byte EncodeAuthPacket(MqttAuthPacket packet, MqttPacketWriter packetWriter) + private static byte EncodeAuthPacket(MqttAuthPacket packet, IMqttPacketWriter packetWriter) { packetWriter.Write((byte)packet.ReasonCode); diff --git a/Source/MQTTnet/Formatter/V5/MqttV500PacketFormatter.cs b/Source/MQTTnet/Formatter/V5/MqttV500PacketFormatter.cs index ab52885..9a52a98 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500PacketFormatter.cs @@ -6,9 +6,19 @@ namespace MQTTnet.Formatter.V5 { public class MqttV500PacketFormatter : IMqttPacketFormatter { - private readonly MqttV500PacketEncoder _encoder = new MqttV500PacketEncoder(); + private readonly MqttV500PacketEncoder _encoder; private readonly MqttV500PacketDecoder _decoder = new MqttV500PacketDecoder(); + public MqttV500PacketFormatter() + { + _encoder = new MqttV500PacketEncoder(); + } + + public MqttV500PacketFormatter(IMqttPacketWriter writer) + { + _encoder = new MqttV500PacketEncoder(writer); + } + public IMqttDataConverter DataConverter { get; } = new MqttV500DataConverter(); public ArraySegment Encode(MqttBasePacket mqttPacket) diff --git a/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs b/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs index e10790c..9e5fe39 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs @@ -63,7 +63,7 @@ namespace MQTTnet.Formatter.V5 Write(MqttPropertyId.AuthenticationMethod, value); } - public void WriteToPacket(MqttPacketWriter packetWriter) + public void WriteToPacket(IMqttPacketWriter packetWriter) { if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); diff --git a/Tests/MQTTnet.AspNetCore.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.AspNetCore.Tests/MqttPacketSerializerTests.cs index 65625d3..2d389ab 100644 --- a/Tests/MQTTnet.AspNetCore.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.AspNetCore.Tests/MqttPacketSerializerTests.cs @@ -13,5 +13,10 @@ namespace MQTTnet.AspNetCore.Tests result.SetBuffer(data); return result; } + + protected override IMqttPacketWriter WriterFactory() + { + return new SpanBasedMqttPacketWriter(); + } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializer_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializer_Tests.cs index cb1b069..e4dd135 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializer_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializer_Tests.cs @@ -528,6 +528,11 @@ namespace MQTTnet.Tests Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(data))); } + protected virtual IMqttPacketWriter WriterFactory() + { + return new MqttPacketWriter(); + } + protected virtual IMqttPacketBodyReader ReaderFactory(byte[] data) { return new MqttPacketBodyReader(data, 0, data.Length); @@ -535,14 +540,16 @@ namespace MQTTnet.Tests private void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { + var writer = WriterFactory(); + IMqttPacketFormatter serializer; if (protocolVersion == MqttProtocolVersion.V311) { - serializer = new MqttV311PacketFormatter(); + serializer = new MqttV311PacketFormatter(writer); } else if (protocolVersion == MqttProtocolVersion.V310) { - serializer = new MqttV310PacketFormatter(); + serializer = new MqttV310PacketFormatter(writer); } else { @@ -559,7 +566,8 @@ namespace MQTTnet.Tests using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) { - var deserializedPacket = serializer.Decode(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray(), 0, (int)bodyStream.Length), 0)); + var reader = ReaderFactory(bodyStream.ToArray()); + var deserializedPacket = serializer.Decode(new ReceivedMqttPacket(header.Flags, reader, 0)); var buffer2 = serializer.Encode(deserializedPacket); Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(buffer2))); @@ -567,17 +575,19 @@ namespace MQTTnet.Tests } } - private static T Roundtrip(T packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + private T Roundtrip(T packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) where T : MqttBasePacket { + var writer = WriterFactory(); + IMqttPacketFormatter serializer; if (protocolVersion == MqttProtocolVersion.V311) { - serializer = new MqttV311PacketFormatter(); + serializer = new MqttV311PacketFormatter(writer); } else if (protocolVersion == MqttProtocolVersion.V310) { - serializer = new MqttV310PacketFormatter(); + serializer = new MqttV310PacketFormatter(writer); } else { @@ -595,7 +605,8 @@ namespace MQTTnet.Tests using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, (int)header.RemainingLength)) { - return (T)serializer.Decode(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray(), 0, (int)bodyStream.Length), 0)); + var reader = ReaderFactory(bodyStream.ToArray()); + return (T)serializer.Decode(new ReceivedMqttPacket(header.Flags, reader, 0)); } } }