From 644bcfba2777c9817e5d2a837d91f8863f98d7b1 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Sun, 10 Sep 2017 15:41:39 +0200 Subject: [PATCH] - make MqttPacketReader and Writer a specialized version of their Binary counterparts - improved message framing. as basis to improve buffering - share more common code for the writer part --- MQTTnet.Core/Packets/MqttPacketHeader.cs | 13 + MQTTnet.Core/Serializer/MqttPacketReader.cs | 139 +++--- .../Serializer/MqttPacketSerializer.cs | 420 +++++++++--------- MQTTnet.Core/Serializer/MqttPacketWriter.cs | 72 +-- 4 files changed, 287 insertions(+), 357 deletions(-) create mode 100644 MQTTnet.Core/Packets/MqttPacketHeader.cs diff --git a/MQTTnet.Core/Packets/MqttPacketHeader.cs b/MQTTnet.Core/Packets/MqttPacketHeader.cs new file mode 100644 index 0000000..41646df --- /dev/null +++ b/MQTTnet.Core/Packets/MqttPacketHeader.cs @@ -0,0 +1,13 @@ +using MQTTnet.Core.Protocol; + +namespace MQTTnet.Core.Packets +{ + public class MqttPacketHeader + { + public MqttControlPacketType ControlPacketType { get; set; } + + public byte FixedHeader { get; set; } + + public int BodyLength { get; set; } + } +} diff --git a/MQTTnet.Core/Serializer/MqttPacketReader.cs b/MQTTnet.Core/Serializer/MqttPacketReader.cs index f64c0c9..af27858 100644 --- a/MQTTnet.Core/Serializer/MqttPacketReader.cs +++ b/MQTTnet.Core/Serializer/MqttPacketReader.cs @@ -2,55 +2,28 @@ using System.IO; using System.Text; using System.Threading.Tasks; -using MQTTnet.Core.Channel; using MQTTnet.Core.Exceptions; using MQTTnet.Core.Protocol; +using MQTTnet.Core.Channel; +using MQTTnet.Core.Packets; namespace MQTTnet.Core.Serializer { - public sealed class MqttPacketReader : IDisposable + public sealed class MqttPacketReader : BinaryReader { - private readonly MemoryStream _remainingData = new MemoryStream(1024); - private readonly IMqttCommunicationChannel _source; + private readonly MqttPacketHeader _header; - private int _remainingLength; - - public MqttPacketReader(IMqttCommunicationChannel source) + public MqttPacketReader(Stream stream, MqttPacketHeader header) + : base(stream) { - _source = source ?? throw new ArgumentNullException(nameof(source)); + _header = header; } - - public MqttControlPacketType ControlPacketType { get; private set; } - - public byte FixedHeader { get; private set; } - - public bool EndOfRemainingData => _remainingData.Position == _remainingData.Length; - - public async Task ReadToEndAsync() + + public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; + + public override ushort ReadUInt16() { - await ReadFixedHeaderAsync().ConfigureAwait(false); - await ReadRemainingLengthAsync().ConfigureAwait(false); - - if (_remainingLength == 0) - { - return; - } - - var buffer = new byte[_remainingLength]; - await ReadFromSourceAsync(buffer).ConfigureAwait(false); - - _remainingData.Write(buffer, 0, buffer.Length); - _remainingData.Position = 0; - } - - public byte ReadRemainingDataByte() - { - return ReadRemainingData(1)[0]; - } - - public ushort ReadRemainingDataUShort() - { - var buffer = ReadRemainingData(2); + var buffer = ReadBytes(2); var temp = buffer[0]; buffer[0] = buffer[1]; @@ -59,31 +32,59 @@ namespace MQTTnet.Core.Serializer return BitConverter.ToUInt16(buffer, 0); } - public string ReadRemainingDataStringWithLengthPrefix() + public string ReadStringWithLengthPrefix() { - var buffer = ReadRemainingDataWithLengthPrefix(); + var buffer = ReadWithLengthPrefix(); return Encoding.UTF8.GetString(buffer, 0, buffer.Length); } - public byte[] ReadRemainingDataWithLengthPrefix() + public byte[] ReadWithLengthPrefix() { - var length = ReadRemainingDataUShort(); - return ReadRemainingData(length); + var length = ReadUInt16(); + return ReadBytes(length); } public byte[] ReadRemainingData() { - return ReadRemainingData(_remainingLength - (int)_remainingData.Position); + return ReadBytes(_header.BodyLength - (int)BaseStream.Position); + } + + public static async Task ReadHeaderFromSourceAsync(IMqttCommunicationChannel source) + { + var fixedHeader = await ReadStreamByteAsync(source).ConfigureAwait(false); + var byteReader = new ByteReader(fixedHeader); + byteReader.Read(4); + var controlPacketType = (MqttControlPacketType)byteReader.Read(4); + var bodyLength = await ReadBodyLengthFromSourceAsync(source).ConfigureAwait(false); + + return new MqttPacketHeader() + { + FixedHeader = fixedHeader, + ControlPacketType = controlPacketType, + BodyLength = bodyLength + }; } - public byte[] ReadRemainingData(int length) + private static async Task ReadStreamByteAsync(IMqttCommunicationChannel source) { - var buffer = new byte[length]; - _remainingData.Read(buffer, 0, buffer.Length); - return buffer; + var buffer = new byte[1]; + await ReadFromSourceAsync(source, buffer).ConfigureAwait(false); + return buffer[0]; } - private async Task ReadRemainingLengthAsync() + public static async Task ReadFromSourceAsync(IMqttCommunicationChannel source, byte[] buffer) + { + try + { + await source.ReadAsync(buffer); + } + catch (Exception exception) + { + throw new MqttCommunicationException(exception); + } + } + + private static async Task ReadBodyLengthFromSourceAsync(IMqttCommunicationChannel source) { // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. var multiplier = 1; @@ -91,7 +92,7 @@ namespace MQTTnet.Core.Serializer byte encodedByte; do { - encodedByte = await ReadStreamByteAsync().ConfigureAwait(false); + encodedByte = await ReadStreamByteAsync(source).ConfigureAwait(false); value += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128 * 128 * 128) @@ -99,41 +100,7 @@ namespace MQTTnet.Core.Serializer throw new MqttProtocolViolationException("Remaining length is ivalid."); } } while ((encodedByte & 128) != 0); - - _remainingLength = value; - } - - private Task ReadFromSourceAsync(byte[] buffer) - { - try - { - return _source.ReadAsync(buffer); - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } - } - - private async Task ReadStreamByteAsync() - { - var buffer = new byte[1]; - await ReadFromSourceAsync(buffer).ConfigureAwait(false); - return buffer[0]; - } - - private async Task ReadFixedHeaderAsync() - { - FixedHeader = await ReadStreamByteAsync().ConfigureAwait(false); - - var byteReader = new ByteReader(FixedHeader); - byteReader.Read(4); - ControlPacketType = (MqttControlPacketType)byteReader.Read(4); - } - - public void Dispose() - { - _remainingData?.Dispose(); + return value; } } } diff --git a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs index e2e8f44..88a7cfe 100644 --- a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.IO; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -16,79 +18,95 @@ namespace MQTTnet.Core.Serializer public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; - public Task SerializeAsync(MqttBasePacket packet, IMqttCommunicationChannel destination) + public async Task SerializeAsync(MqttBasePacket packet, IMqttCommunicationChannel destination) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (destination == null) throw new ArgumentNullException(nameof(destination)); + using (var stream = new MemoryStream()) + using (var writer = new MqttPacketWriter(stream)) + { + var header = new List(); + header.Add(SerializePacket(packet, writer)); + + var body = stream.ToArray(); + MqttPacketWriter.BuildLengthHeader(body.Length, header); + + await destination.WriteAsync(header.ToArray()).ConfigureAwait(false); + await destination.WriteAsync(body).ConfigureAwait(false); + } + } + + private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter writer) + { if (packet is MqttConnectPacket connectPacket) { - return SerializeAsync(connectPacket, destination); + return Serialize(connectPacket, writer); } if (packet is MqttConnAckPacket connAckPacket) { - return SerializeAsync(connAckPacket, destination); + return Serialize(connAckPacket, writer); } if (packet is MqttDisconnectPacket disconnectPacket) { - return SerializeAsync(disconnectPacket, destination); + return Serialize(disconnectPacket, writer); } if (packet is MqttPingReqPacket pingReqPacket) { - return SerializeAsync(pingReqPacket, destination); + return Serialize(pingReqPacket, writer); } if (packet is MqttPingRespPacket pingRespPacket) { - return SerializeAsync(pingRespPacket, destination); + return Serialize(pingRespPacket, writer); } if (packet is MqttPublishPacket publishPacket) { - return SerializeAsync(publishPacket, destination); + return Serialize(publishPacket, writer); } if (packet is MqttPubAckPacket pubAckPacket) { - return SerializeAsync(pubAckPacket, destination); + return Serialize(pubAckPacket, writer); } if (packet is MqttPubRecPacket pubRecPacket) { - return SerializeAsync(pubRecPacket, destination); + return Serialize(pubRecPacket, writer); } if (packet is MqttPubRelPacket pubRelPacket) { - return SerializeAsync(pubRelPacket, destination); + return Serialize(pubRelPacket, writer); } if (packet is MqttPubCompPacket pubCompPacket) { - return SerializeAsync(pubCompPacket, destination); + return Serialize(pubCompPacket, writer); } if (packet is MqttSubscribePacket subscribePacket) { - return SerializeAsync(subscribePacket, destination); + return Serialize(subscribePacket, writer); } if (packet is MqttSubAckPacket subAckPacket) { - return SerializeAsync(subAckPacket, destination); + return Serialize(subAckPacket, writer); } if (packet is MqttUnsubscribePacket unsubscribePacket) { - return SerializeAsync(unsubscribePacket, destination); + return Serialize(unsubscribePacket, writer); } if (packet is MqttUnsubAckPacket unsubAckPacket) { - return SerializeAsync(unsubAckPacket, destination); + return Serialize(unsubAckPacket, writer); } throw new MqttProtocolViolationException("Packet type invalid."); @@ -97,12 +115,18 @@ namespace MQTTnet.Core.Serializer public async Task DeserializeAsync(IMqttCommunicationChannel source) { if (source == null) throw new ArgumentNullException(nameof(source)); + + var header = await MqttPacketReader.ReadHeaderFromSourceAsync(source).ConfigureAwait(false); - using (var mqttPacketReader = new MqttPacketReader(source)) + var body = new byte[header.BodyLength]; + if (header.BodyLength > 0) { - await mqttPacketReader.ReadToEndAsync().ConfigureAwait(false); + await MqttPacketReader.ReadFromSourceAsync(source, body).ConfigureAwait(false); + } - switch (mqttPacketReader.ControlPacketType) + using (var mqttPacketReader = new MqttPacketReader(new MemoryStream(body), header)) + { + switch (header.ControlPacketType) { case MqttControlPacketType.Connect: { @@ -121,14 +145,14 @@ namespace MQTTnet.Core.Serializer case MqttControlPacketType.Publish: { - return DeserializePublish(mqttPacketReader); + return DeserializePublish(mqttPacketReader, header); } case MqttControlPacketType.PubAck: { return new MqttPubAckPacket { - PacketIdentifier = mqttPacketReader.ReadRemainingDataUShort() + PacketIdentifier = mqttPacketReader.ReadUInt16() }; } @@ -136,7 +160,7 @@ namespace MQTTnet.Core.Serializer { return new MqttPubRecPacket { - PacketIdentifier = mqttPacketReader.ReadRemainingDataUShort() + PacketIdentifier = mqttPacketReader.ReadUInt16() }; } @@ -144,7 +168,7 @@ namespace MQTTnet.Core.Serializer { return new MqttPubRelPacket { - PacketIdentifier = mqttPacketReader.ReadRemainingDataUShort() + PacketIdentifier = mqttPacketReader.ReadUInt16() }; } @@ -152,7 +176,7 @@ namespace MQTTnet.Core.Serializer { return new MqttPubCompPacket { - PacketIdentifier = mqttPacketReader.ReadRemainingDataUShort() + PacketIdentifier = mqttPacketReader.ReadUInt16() }; } @@ -185,13 +209,13 @@ namespace MQTTnet.Core.Serializer { return new MqttUnsubAckPacket { - PacketIdentifier = mqttPacketReader.ReadRemainingDataUShort() + PacketIdentifier = mqttPacketReader.ReadUInt16() }; } default: { - throw new MqttProtocolViolationException($"Packet type ({(int)mqttPacketReader.ControlPacketType}) not supported."); + throw new MqttProtocolViolationException($"Packet type ({(int)header.ControlPacketType}) not supported."); } } } @@ -201,12 +225,12 @@ namespace MQTTnet.Core.Serializer { var packet = new MqttUnsubscribePacket { - PacketIdentifier = reader.ReadRemainingDataUShort(), + PacketIdentifier = reader.ReadUInt16(), }; while (!reader.EndOfRemainingData) { - packet.TopicFilters.Add(reader.ReadRemainingDataStringWithLengthPrefix()); + packet.TopicFilters.Add(reader.ReadStringWithLengthPrefix()); } return packet; @@ -216,32 +240,32 @@ namespace MQTTnet.Core.Serializer { var packet = new MqttSubscribePacket { - PacketIdentifier = reader.ReadRemainingDataUShort() + PacketIdentifier = reader.ReadUInt16() }; while (!reader.EndOfRemainingData) { packet.TopicFilters.Add(new TopicFilter( - reader.ReadRemainingDataStringWithLengthPrefix(), - (MqttQualityOfServiceLevel)reader.ReadRemainingDataByte())); + reader.ReadStringWithLengthPrefix(), + (MqttQualityOfServiceLevel)reader.ReadByte())); } return packet; } - private static MqttBasePacket DeserializePublish(MqttPacketReader reader) + private static MqttBasePacket DeserializePublish(MqttPacketReader reader, MqttPacketHeader mqttPacketHeader) { - var fixedHeader = new ByteReader(reader.FixedHeader); + var fixedHeader = new ByteReader(mqttPacketHeader.FixedHeader); var retain = fixedHeader.Read(); var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2); var dup = fixedHeader.Read(); - var topic = reader.ReadRemainingDataStringWithLengthPrefix(); + var topic = reader.ReadStringWithLengthPrefix(); ushort packetIdentifier = 0; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - packetIdentifier = reader.ReadRemainingDataUShort(); + packetIdentifier = reader.ReadUInt16(); } var packet = new MqttPublishPacket @@ -259,13 +283,13 @@ namespace MQTTnet.Core.Serializer private static MqttBasePacket DeserializeConnect(MqttPacketReader reader) { - reader.ReadRemainingData(2); // Skip 2 bytes + reader.ReadBytes(2); // Skip 2 bytes MqttProtocolVersion protocolVersion; - var protocolName = reader.ReadRemainingData(4); + var protocolName = reader.ReadBytes(4); if (protocolName.SequenceEqual(ProtocolVersionV310Name)) { - reader.ReadRemainingData(2); + reader.ReadBytes(2); protocolVersion = MqttProtocolVersion.V310; } else if (protocolName.SequenceEqual(ProtocolVersionV311Name)) @@ -277,8 +301,8 @@ namespace MQTTnet.Core.Serializer throw new MqttProtocolViolationException("Protocol name is not supported."); } - var protocolLevel = reader.ReadRemainingDataByte(); - var connectFlags = reader.ReadRemainingDataByte(); + var protocolLevel = reader.ReadByte(); + var connectFlags = reader.ReadByte(); var connectFlagsReader = new ByteReader(connectFlags); connectFlagsReader.Read(); // Reserved. @@ -295,26 +319,26 @@ namespace MQTTnet.Core.Serializer var passwordFlag = connectFlagsReader.Read(); var usernameFlag = connectFlagsReader.Read(); - packet.KeepAlivePeriod = reader.ReadRemainingDataUShort(); - packet.ClientId = reader.ReadRemainingDataStringWithLengthPrefix(); + packet.KeepAlivePeriod = reader.ReadUInt16(); + packet.ClientId = reader.ReadStringWithLengthPrefix(); if (willFlag) { packet.WillMessage = new MqttApplicationMessage( - reader.ReadRemainingDataStringWithLengthPrefix(), - reader.ReadRemainingDataWithLengthPrefix(), + reader.ReadStringWithLengthPrefix(), + reader.ReadWithLengthPrefix(), (MqttQualityOfServiceLevel)willQoS, willRetain); } if (usernameFlag) { - packet.Username = reader.ReadRemainingDataStringWithLengthPrefix(); + packet.Username = reader.ReadStringWithLengthPrefix(); } if (passwordFlag) { - packet.Password = reader.ReadRemainingDataStringWithLengthPrefix(); + packet.Password = reader.ReadStringWithLengthPrefix(); } ValidateConnectPacket(packet); @@ -325,12 +349,12 @@ namespace MQTTnet.Core.Serializer { var packet = new MqttSubAckPacket { - PacketIdentifier = reader.ReadRemainingDataUShort() + PacketIdentifier = reader.ReadUInt16() }; while (!reader.EndOfRemainingData) { - packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)reader.ReadRemainingDataByte()); + packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)reader.ReadByte()); } return packet; @@ -338,8 +362,8 @@ namespace MQTTnet.Core.Serializer private static MqttBasePacket DeserializeConnAck(MqttPacketReader reader) { - var variableHeader1 = reader.ReadRemainingDataByte(); - var variableHeader2 = reader.ReadRemainingDataByte(); + var variableHeader1 = reader.ReadByte(); + var variableHeader2 = reader.ReadByte(); var packet = new MqttConnAckPacket { @@ -366,260 +390,212 @@ namespace MQTTnet.Core.Serializer } } - private Task SerializeAsync(MqttConnectPacket packet, IMqttCommunicationChannel destination) + private byte Serialize(MqttConnectPacket packet, MqttPacketWriter writer) { ValidateConnectPacket(packet); - using (var output = new MqttPacketWriter()) + // Write variable header + writer.Write(0x00, 0x04); // 3.1.2.1 Protocol Name + if (ProtocolVersion == MqttProtocolVersion.V311) { - // Write variable header - output.Write(0x00, 0x04); // 3.1.2.1 Protocol Name - if (ProtocolVersion == MqttProtocolVersion.V311) - { - output.Write(ProtocolVersionV311Name); - output.Write(0x04); // 3.1.2.2 Protocol Level (4) - } - else - { - output.Write(ProtocolVersionV310Name); - output.Write(0x64); - output.Write(0x70); - output.Write(0x03); // Protocol Level (3) - } - - var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags - connectFlags.Write(false); // Reserved - connectFlags.Write(packet.CleanSession); - connectFlags.Write(packet.WillMessage != null); + writer.Write(ProtocolVersionV311Name); + writer.Write(0x04); // 3.1.2.2 Protocol Level (4) + } + else + { + writer.Write(ProtocolVersionV310Name); + writer.Write(0x64); + writer.Write(0x70); + writer.Write(0x03); // Protocol Level (3) + } - if (packet.WillMessage != null) - { - connectFlags.Write((int)packet.WillMessage.QualityOfServiceLevel, 2); - connectFlags.Write(packet.WillMessage.Retain); - } - else - { - connectFlags.Write(0, 2); - connectFlags.Write(false); - } - - connectFlags.Write(packet.Password != null); - connectFlags.Write(packet.Username != null); + var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags + connectFlags.Write(false); // Reserved + connectFlags.Write(packet.CleanSession); + connectFlags.Write(packet.WillMessage != null); - output.Write(connectFlags); - output.Write(packet.KeepAlivePeriod); - output.WriteWithLengthPrefix(packet.ClientId); + if (packet.WillMessage != null) + { + connectFlags.Write((int)packet.WillMessage.QualityOfServiceLevel, 2); + connectFlags.Write(packet.WillMessage.Retain); + } + else + { + connectFlags.Write(0, 2); + connectFlags.Write(false); + } - if (packet.WillMessage != null) - { - output.WriteWithLengthPrefix(packet.WillMessage.Topic); - output.WriteWithLengthPrefix(packet.WillMessage.Payload); - } + connectFlags.Write(packet.Password != null); + connectFlags.Write(packet.Username != null); - if (packet.Username != null) - { - output.WriteWithLengthPrefix(packet.Username); - } + writer.Write(connectFlags); + writer.Write(packet.KeepAlivePeriod); + writer.WriteWithLengthPrefix(packet.ClientId); - if (packet.Password != null) - { - output.WriteWithLengthPrefix(packet.Password); - } + if (packet.WillMessage != null) + { + writer.WriteWithLengthPrefix(packet.WillMessage.Topic); + writer.WriteWithLengthPrefix(packet.WillMessage.Payload); + } + + if (packet.Username != null) + { + writer.WriteWithLengthPrefix(packet.Username); + } - output.InjectFixedHeader(MqttControlPacketType.Connect); - return output.WriteToAsync(destination); + if (packet.Password != null) + { + writer.WriteWithLengthPrefix(packet.Password); } + + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - private Task SerializeAsync(MqttConnAckPacket packet, IMqttCommunicationChannel destination) + private byte Serialize(MqttConnAckPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) + var connectAcknowledgeFlags = new ByteWriter(); + + if (ProtocolVersion == MqttProtocolVersion.V311) { - var connectAcknowledgeFlags = new ByteWriter(); + connectAcknowledgeFlags.Write(packet.IsSessionPresent); + } - if (ProtocolVersion == MqttProtocolVersion.V311) - { - connectAcknowledgeFlags.Write(packet.IsSessionPresent); - } - - output.Write(connectAcknowledgeFlags); - output.Write((byte)packet.ConnectReturnCode); + writer.Write(connectAcknowledgeFlags); + writer.Write((byte)packet.ConnectReturnCode); - output.InjectFixedHeader(MqttControlPacketType.ConnAck); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static async Task SerializeAsync(MqttPubRelPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - output.InjectFixedHeader(MqttControlPacketType.PubRel, 0x02); - await output.WriteToAsync(destination).ConfigureAwait(false); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static Task SerializeAsync(MqttDisconnectPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttDisconnectPacket packet, MqttPacketWriter writer) { - return SerializeEmptyPacketAsync(MqttControlPacketType.Disconnect, destination); + return SerializeEmptyPacketAsync(MqttControlPacketType.Disconnect, writer); } - private static Task SerializeAsync(MqttPingReqPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPingReqPacket packet, MqttPacketWriter writer) { - return SerializeEmptyPacketAsync(MqttControlPacketType.PingReq, destination); + return SerializeEmptyPacketAsync(MqttControlPacketType.PingReq, writer); } - private static Task SerializeAsync(MqttPingRespPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPingRespPacket packet, MqttPacketWriter writer) { - return SerializeEmptyPacketAsync(MqttControlPacketType.PingResp, destination); + return SerializeEmptyPacketAsync(MqttControlPacketType.PingResp, writer); } - private static Task SerializeAsync(MqttPublishPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPublishPacket packet, MqttPacketWriter writer) { ValidatePublishPacket(packet); - using (var output = new MqttPacketWriter()) - { - output.WriteWithLengthPrefix(packet.Topic); + writer.WriteWithLengthPrefix(packet.Topic); - if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) - { - output.Write(packet.PacketIdentifier); - } - else + if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) + { + writer.Write(packet.PacketIdentifier); + } + else + { + if (packet.PacketIdentifier > 0) { - if (packet.PacketIdentifier > 0) - { - throw new MqttProtocolViolationException("Packet identifier must be empty if QoS == 0 [MQTT-2.3.1-5]."); - } + throw new MqttProtocolViolationException("Packet identifier must be empty if QoS == 0 [MQTT-2.3.1-5]."); } + } - if (packet.Payload?.Length > 0) - { - output.Write(packet.Payload); - } + if (packet.Payload?.Length > 0) + { + writer.Write(packet.Payload); + } - var fixedHeader = new ByteWriter(); - fixedHeader.Write(packet.Retain); - fixedHeader.Write((byte)packet.QualityOfServiceLevel, 2); - fixedHeader.Write(packet.Dup); + var fixedHeader = new ByteWriter(); + fixedHeader.Write(packet.Retain); + fixedHeader.Write((byte)packet.QualityOfServiceLevel, 2); + fixedHeader.Write(packet.Dup); - output.InjectFixedHeader(MqttControlPacketType.Publish, fixedHeader.Value); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader.Value); } - private static Task SerializeAsync(MqttPubAckPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - output.InjectFixedHeader(MqttControlPacketType.PubAck); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static Task SerializeAsync(MqttPubRecPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - output.InjectFixedHeader(MqttControlPacketType.PubRec); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static Task SerializeAsync(MqttPubCompPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - output.InjectFixedHeader(MqttControlPacketType.PubComp); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static Task SerializeAsync(MqttSubscribePacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttSubscribePacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - if (packet.TopicFilters?.Count > 0) + if (packet.TopicFilters?.Count > 0) + { + foreach (var topicFilter in packet.TopicFilters) { - foreach (var topicFilter in packet.TopicFilters) - { - output.WriteWithLengthPrefix(topicFilter.Topic); - output.Write((byte)topicFilter.QualityOfServiceLevel); - } + writer.WriteWithLengthPrefix(topicFilter.Topic); + writer.Write((byte)topicFilter.QualityOfServiceLevel); } - - output.InjectFixedHeader(MqttControlPacketType.Subscribe, 0x02); - return output.WriteToAsync(destination); } + + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static Task SerializeAsync(MqttSubAckPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - if (packet.SubscribeReturnCodes?.Any() == true) + if (packet.SubscribeReturnCodes?.Any() == true) + { + foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes) { - foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes) - { - output.Write((byte)packetSubscribeReturnCode); - } + writer.Write((byte)packetSubscribeReturnCode); } - - output.InjectFixedHeader(MqttControlPacketType.SubAck); - return output.WriteToAsync(destination); } + + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static Task SerializeAsync(MqttUnsubscribePacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttUnsubscribePacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - if (packet.TopicFilters?.Any() == true) + if (packet.TopicFilters?.Any() == true) + { + foreach (var topicFilter in packet.TopicFilters) { - foreach (var topicFilter in packet.TopicFilters) - { - output.WriteWithLengthPrefix(topicFilter); - } + writer.WriteWithLengthPrefix(topicFilter); } - - output.InjectFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); - return output.WriteToAsync(destination); } + + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static Task SerializeAsync(MqttUnsubAckPacket packet, IMqttCommunicationChannel destination) + private static byte Serialize(MqttUnsubAckPacket packet, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.Write(packet.PacketIdentifier); + writer.Write(packet.PacketIdentifier); - output.InjectFixedHeader(MqttControlPacketType.UnsubAck); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } - private static Task SerializeEmptyPacketAsync(MqttControlPacketType type, IMqttCommunicationChannel destination) + private static byte SerializeEmptyPacketAsync(MqttControlPacketType type, MqttPacketWriter writer) { - using (var output = new MqttPacketWriter()) - { - output.InjectFixedHeader(type); - return output.WriteToAsync(destination); - } + return MqttPacketWriter.BuildFixedHeader(type); } } } diff --git a/MQTTnet.Core/Serializer/MqttPacketWriter.cs b/MQTTnet.Core/Serializer/MqttPacketWriter.cs index 778f52e..d5aa47a 100644 --- a/MQTTnet.Core/Serializer/MqttPacketWriter.cs +++ b/MQTTnet.Core/Serializer/MqttPacketWriter.cs @@ -1,49 +1,45 @@ using System; using System.IO; using System.Text; -using System.Threading.Tasks; -using MQTTnet.Core.Channel; using MQTTnet.Core.Protocol; +using System.Collections.Generic; namespace MQTTnet.Core.Serializer { - public sealed class MqttPacketWriter : IDisposable + public sealed class MqttPacketWriter : BinaryWriter { - private readonly MemoryStream _buffer = new MemoryStream(1024); - - public void InjectFixedHeader(MqttControlPacketType packetType, byte flags = 0) + public MqttPacketWriter( Stream stream ) + : base(stream) { - var fixedHeader = (int)packetType << 4; - fixedHeader |= flags; - InjectFixedHeader((byte)fixedHeader); + } - public void Write(byte value) + public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) { - _buffer.WriteByte(value); + var fixedHeader = (int)packetType << 4; + fixedHeader |= flags; + return (byte)fixedHeader; } - - public void Write(ushort value) + + public override void Write(ushort value) { var buffer = BitConverter.GetBytes(value); - _buffer.WriteByte(buffer[1]); - _buffer.WriteByte(buffer[0]); + Write(buffer[1]); + Write(buffer[0]); } - public void Write(ByteWriter value) + public new void Write(params byte[] values) { - if (value == null) throw new ArgumentNullException(nameof(value)); - - _buffer.WriteByte(value.Value); + base.Write(values); } - public void Write(params byte[] value) + public void Write(ByteWriter value) { if (value == null) throw new ArgumentNullException(nameof(value)); - _buffer.Write(value, 0, value.Length); + Write(value.Value); } - + public void WriteWithLengthPrefix(string value) { WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); @@ -57,36 +53,16 @@ namespace MQTTnet.Core.Serializer Write(value); } - public Task WriteToAsync(IMqttCommunicationChannel destination) + public static void BuildLengthHeader(int length, List header) { - if (destination == null) throw new ArgumentNullException(nameof(destination)); - - return destination.WriteAsync(_buffer.ToArray()); - } - - public void Dispose() - { - _buffer?.Dispose(); - } - - private void InjectFixedHeader(byte fixedHeader) - { - if (_buffer.Length == 0) + if (length == 0) { - Write(fixedHeader); - Write(0); + header.Add(0); return; } - var backupBuffer = _buffer.ToArray(); - var remainingLength = (int)_buffer.Length; - - _buffer.SetLength(0); - - _buffer.WriteByte(fixedHeader); - // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. - var x = remainingLength; + var x = length; do { var encodedByte = x % 128; @@ -96,10 +72,8 @@ namespace MQTTnet.Core.Serializer encodedByte = encodedByte | 128; } - _buffer.WriteByte((byte)encodedByte); + header.Add((byte)encodedByte); } while (x > 0); - - _buffer.Write(backupBuffer, 0, backupBuffer.Length); } } }