diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index ee2f586..f6abb8d 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -137,7 +137,7 @@ namespace MQTTnet.Adapter if (header.BodyLength == 0) { - return new ReceivedMqttPacket(header, new MemoryStream(new byte[0], false)); + return new ReceivedMqttPacket(header, null); } var body = new MemoryStream(header.BodyLength); diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs b/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs index a44fb54..24651fe 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Adapter public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body) { Header = header ?? throw new ArgumentNullException(nameof(header)); - Body = body ?? throw new ArgumentNullException(nameof(body)); + Body = body; } public MqttPacketHeader Header { get; } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 943d49e..75da85d 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -4,15 +4,11 @@ using MQTTnet.Protocol; using System; using System.IO; using System.Linq; -using System.Text; namespace MQTTnet.Serializer { public sealed class MqttPacketSerializer : IMqttPacketSerializer { - private static byte[] ProtocolVersionV311Name { get; } = Encoding.UTF8.GetBytes("MQTT"); - private static byte[] ProtocolVersionV310Name { get; } = Encoding.UTF8.GetBytes("MQIsdp"); - public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; public ArraySegment Serialize(MqttBasePacket packet) @@ -67,116 +63,132 @@ namespace MQTTnet.Serializer } } - public MqttBasePacket Deserialize(MqttPacketHeader header, Stream stream) + public MqttBasePacket Deserialize(MqttPacketHeader header, Stream body) { if (header == null) throw new ArgumentNullException(nameof(header)); - if (stream == null) throw new ArgumentNullException(nameof(stream)); + if (body == null) throw new ArgumentNullException(nameof(body)); switch (header.ControlPacketType) { - case MqttControlPacketType.Connect: return DeserializeConnect(stream); - case MqttControlPacketType.ConnAck: return DeserializeConnAck(stream); + case MqttControlPacketType.Connect: return DeserializeConnect(body); + case MqttControlPacketType.ConnAck: return DeserializeConnAck(body); case MqttControlPacketType.Disconnect: return new MqttDisconnectPacket(); - case MqttControlPacketType.Publish: return DeserializePublish(stream, header); - case MqttControlPacketType.PubAck: return DeserializePubAck(stream); - case MqttControlPacketType.PubRec: return DeserializePubRec(stream); - case MqttControlPacketType.PubRel: return DeserializePubRel(stream); - case MqttControlPacketType.PubComp: return DeserializePubComp(stream); + case MqttControlPacketType.Publish: return DeserializePublish(header, body); + case MqttControlPacketType.PubAck: return DeserializePubAck(body); + case MqttControlPacketType.PubRec: return DeserializePubRec(body); + case MqttControlPacketType.PubRel: return DeserializePubRel(body); + case MqttControlPacketType.PubComp: return DeserializePubComp(body); case MqttControlPacketType.PingReq: return new MqttPingReqPacket(); case MqttControlPacketType.PingResp: return new MqttPingRespPacket(); - case MqttControlPacketType.Subscribe: return DeserializeSubscribe(stream, header); - case MqttControlPacketType.SubAck: return DeserializeSubAck(stream, header); - case MqttControlPacketType.Unsubscibe: return DeserializeUnsubscribe(stream, header); - case MqttControlPacketType.UnsubAck: return DeserializeUnsubAck(stream); + case MqttControlPacketType.Subscribe: return DeserializeSubscribe(header, body); + case MqttControlPacketType.SubAck: return DeserializeSubAck(header, body); + case MqttControlPacketType.Unsubscibe: return DeserializeUnsubscribe(header, body); + case MqttControlPacketType.UnsubAck: return DeserializeUnsubAck(body); default: throw new MqttProtocolViolationException($"Packet type ({(int)header.ControlPacketType}) not supported."); } } - private static MqttBasePacket DeserializeUnsubAck(Stream stream) + private static MqttBasePacket DeserializeUnsubAck(Stream body) { + ThrowIfBodyIsEmpty(body); + return new MqttUnsubAckPacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; } - private static MqttBasePacket DeserializePubComp(Stream stream) + private static MqttBasePacket DeserializePubComp(Stream body) { + ThrowIfBodyIsEmpty(body); + return new MqttPubCompPacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; } - private static MqttBasePacket DeserializePubRel(Stream stream) + private static MqttBasePacket DeserializePubRel(Stream body) { + ThrowIfBodyIsEmpty(body); + return new MqttPubRelPacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; } - private static MqttBasePacket DeserializePubRec(Stream stream) + private static MqttBasePacket DeserializePubRec(Stream body) { + ThrowIfBodyIsEmpty(body); + return new MqttPubRecPacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; } - private static MqttBasePacket DeserializePubAck(Stream stream) + private static MqttBasePacket DeserializePubAck(Stream body) { + ThrowIfBodyIsEmpty(body); + return new MqttPubAckPacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; } - private static MqttBasePacket DeserializeUnsubscribe(Stream stream, MqttPacketHeader header) + private static MqttBasePacket DeserializeUnsubscribe(MqttPacketHeader header, Stream body) { + ThrowIfBodyIsEmpty(body); + var packet = new MqttUnsubscribePacket { - PacketIdentifier = stream.ReadUInt16(), + PacketIdentifier = body.ReadUInt16(), }; - while (stream.Position != header.BodyLength) + while (body.Position != header.BodyLength) { - packet.TopicFilters.Add(stream.ReadStringWithLengthPrefix()); + packet.TopicFilters.Add(body.ReadStringWithLengthPrefix()); } return packet; } - private static MqttBasePacket DeserializeSubscribe(Stream stream, MqttPacketHeader header) + private static MqttBasePacket DeserializeSubscribe(MqttPacketHeader header, Stream body) { + ThrowIfBodyIsEmpty(body); + var packet = new MqttSubscribePacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; - while (stream.Position != header.BodyLength) + while (body.Position != header.BodyLength) { packet.TopicFilters.Add(new TopicFilter( - stream.ReadStringWithLengthPrefix(), - (MqttQualityOfServiceLevel)stream.ReadByte())); + body.ReadStringWithLengthPrefix(), + (MqttQualityOfServiceLevel)body.ReadByte())); } return packet; } - private static MqttBasePacket DeserializePublish(Stream stream, MqttPacketHeader mqttPacketHeader) + private static MqttBasePacket DeserializePublish(MqttPacketHeader header, Stream body) { - var fixedHeader = new ByteReader(mqttPacketHeader.FixedHeader); + ThrowIfBodyIsEmpty(body); + + var fixedHeader = new ByteReader(header.FixedHeader); var retain = fixedHeader.Read(); var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2); var dup = fixedHeader.Read(); - var topic = stream.ReadStringWithLengthPrefix(); + var topic = body.ReadStringWithLengthPrefix(); ushort? packetIdentifier = null; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - packetIdentifier = stream.ReadUInt16(); + packetIdentifier = body.ReadUInt16(); } var packet = new MqttPublishPacket @@ -184,7 +196,7 @@ namespace MQTTnet.Serializer PacketIdentifier = packetIdentifier, Retain = retain, Topic = topic, - Payload = stream.ReadRemainingData(mqttPacketHeader), + Payload = body.ReadRemainingData(header), QualityOfServiceLevel = qualityOfServiceLevel, Dup = dup }; @@ -192,57 +204,65 @@ namespace MQTTnet.Serializer return packet; } - private static MqttBasePacket DeserializeConnect(Stream stream) + private static MqttBasePacket DeserializeConnect(Stream body) { - stream.ReadBytes(2); // Skip 2 bytes for header and remaining length. + ThrowIfBodyIsEmpty(body); - MqttProtocolVersion protocolVersion; - var protocolName = stream.ReadBytes(4); + var protocolName = body.ReadStringWithLengthPrefix(); - if (protocolName.SequenceEqual(ProtocolVersionV311Name)) + MqttProtocolVersion protocolVersion; + if (protocolName == "MQTT") { + var protocolLevel = body.ReadByte(); + if (protocolLevel != 4) + { + throw new MqttProtocolViolationException($"Protocol level ({protocolLevel}) not supported for MQTT 3.1.1."); + } + protocolVersion = MqttProtocolVersion.V311; } - else + else if (protocolName == "MQIsdp") { - protocolName = protocolName.Concat(stream.ReadBytes(2)).ToArray(); - if (protocolName.SequenceEqual(ProtocolVersionV310Name)) + var protocolLevel = body.ReadByte(); + if (protocolLevel != 3) { - protocolVersion = MqttProtocolVersion.V310; + throw new MqttProtocolViolationException($"Protocol level ({protocolLevel}) not supported for MQTT 3.1."); } - else - { - throw new MqttProtocolViolationException("Protocol name is not supported."); - } - } - stream.ReadByte(); // Skip protocol level - var connectFlags = stream.ReadByte(); + protocolVersion = MqttProtocolVersion.V310; + } + else + { + throw new MqttProtocolViolationException($"Protocol name ({protocolName}) is not supported."); + } - var connectFlagsReader = new ByteReader(connectFlags); - connectFlagsReader.Read(); // Reserved. + var connectFlags = new ByteReader(body.ReadByte()); + if (connectFlags.Read()) + { + throw new MqttProtocolViolationException("The first bit of the Connect Flags must be set to 0."); + } var packet = new MqttConnectPacket { ProtocolVersion = protocolVersion, - CleanSession = connectFlagsReader.Read() + CleanSession = connectFlags.Read() }; - var willFlag = connectFlagsReader.Read(); - var willQoS = connectFlagsReader.Read(2); - var willRetain = connectFlagsReader.Read(); - var passwordFlag = connectFlagsReader.Read(); - var usernameFlag = connectFlagsReader.Read(); + var willFlag = connectFlags.Read(); + var willQoS = connectFlags.Read(2); + var willRetain = connectFlags.Read(); + var passwordFlag = connectFlags.Read(); + var usernameFlag = connectFlags.Read(); - packet.KeepAlivePeriod = stream.ReadUInt16(); - packet.ClientId = stream.ReadStringWithLengthPrefix(); + packet.KeepAlivePeriod = body.ReadUInt16(); + packet.ClientId = body.ReadStringWithLengthPrefix(); if (willFlag) { packet.WillMessage = new MqttApplicationMessage { - Topic = stream.ReadStringWithLengthPrefix(), - Payload = stream.ReadWithLengthPrefix(), + Topic = body.ReadStringWithLengthPrefix(), + Payload = body.ReadWithLengthPrefix(), QualityOfServiceLevel = (MqttQualityOfServiceLevel)willQoS, Retain = willRetain }; @@ -250,45 +270,49 @@ namespace MQTTnet.Serializer if (usernameFlag) { - packet.Username = stream.ReadStringWithLengthPrefix(); + packet.Username = body.ReadStringWithLengthPrefix(); } if (passwordFlag) { - packet.Password = stream.ReadStringWithLengthPrefix(); + packet.Password = body.ReadStringWithLengthPrefix(); } ValidateConnectPacket(packet); return packet; } - private static MqttBasePacket DeserializeSubAck(Stream stream, MqttPacketHeader header) + private static MqttBasePacket DeserializeSubAck(MqttPacketHeader header, Stream body) { + ThrowIfBodyIsEmpty(body); + var packet = new MqttSubAckPacket { - PacketIdentifier = stream.ReadUInt16() + PacketIdentifier = body.ReadUInt16() }; - while (stream.Position != header.BodyLength) + while (body.Position != header.BodyLength) { - packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)stream.ReadByte()); + packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)body.ReadByte()); } return packet; } - private MqttBasePacket DeserializeConnAck(Stream stream) + private MqttBasePacket DeserializeConnAck(Stream body) { + ThrowIfBodyIsEmpty(body); + var packet = new MqttConnAckPacket(); - var firstByteReader = new ByteReader(stream.ReadByte()); + var firstByteReader = new ByteReader(body.ReadByte()); if (ProtocolVersion == MqttProtocolVersion.V311) { packet.IsSessionPresent = firstByteReader.Read(); } - packet.ConnectReturnCode = (MqttConnectReturnCode)stream.ReadByte(); + packet.ConnectReturnCode = (MqttConnectReturnCode)body.ReadByte(); return packet; } @@ -320,12 +344,12 @@ namespace MQTTnet.Serializer // Write variable header if (ProtocolVersion == MqttProtocolVersion.V311) { - stream.WriteWithLengthPrefix(ProtocolVersionV311Name); + stream.WriteWithLengthPrefix("MQTT"); stream.WriteByte(0x04); // 3.1.2.2 Protocol Level 4 } else { - stream.WriteWithLengthPrefix(ProtocolVersionV310Name); + stream.WriteWithLengthPrefix("MQIsdp"); stream.WriteByte(0x03); // Protocol Level 3 } @@ -571,5 +595,13 @@ namespace MQTTnet.Serializer { return MqttPacketWriter.BuildFixedHeader(type); } + + private static void ThrowIfBodyIsEmpty(Stream body) + { + if (body == null || body.Length == 0) + { + throw new MqttProtocolViolationException("Data from the body is required but not present."); + } + } } }