Browse Source

Add several packet validations.

release/3.x.x
Christian Kratky 6 years ago
parent
commit
62e9a333c7
3 changed files with 114 additions and 82 deletions
  1. +1
    -1
      Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs
  2. +1
    -1
      Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs
  3. +112
    -80
      Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs

+ 1
- 1
Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs View File

@@ -137,7 +137,7 @@ namespace MQTTnet.Adapter
if (header.BodyLength == 0)
{
return new ReceivedMqttPacket(header, new MemoryStream(new byte[0], false));
return new ReceivedMqttPacket(header, null);
}

var body = new MemoryStream(header.BodyLength);


+ 1
- 1
Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs View File

@@ -9,7 +9,7 @@ namespace MQTTnet.Adapter
public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body)
{
Header = header ?? throw new ArgumentNullException(nameof(header));
Body = body ?? throw new ArgumentNullException(nameof(body));
Body = body;
}

public MqttPacketHeader Header { get; }


+ 112
- 80
Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs View File

@@ -4,15 +4,11 @@ using MQTTnet.Protocol;
using System;
using System.IO;
using System.Linq;
using System.Text;

namespace MQTTnet.Serializer
{
public sealed class MqttPacketSerializer : IMqttPacketSerializer
{
private static byte[] ProtocolVersionV311Name { get; } = Encoding.UTF8.GetBytes("MQTT");
private static byte[] ProtocolVersionV310Name { get; } = Encoding.UTF8.GetBytes("MQIsdp");

public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311;

public ArraySegment<byte> Serialize(MqttBasePacket packet)
@@ -67,116 +63,132 @@ namespace MQTTnet.Serializer
}
}

public MqttBasePacket Deserialize(MqttPacketHeader header, Stream stream)
public MqttBasePacket Deserialize(MqttPacketHeader header, Stream body)
{
if (header == null) throw new ArgumentNullException(nameof(header));
if (stream == null) throw new ArgumentNullException(nameof(stream));
if (body == null) throw new ArgumentNullException(nameof(body));

switch (header.ControlPacketType)
{
case MqttControlPacketType.Connect: return DeserializeConnect(stream);
case MqttControlPacketType.ConnAck: return DeserializeConnAck(stream);
case MqttControlPacketType.Connect: return DeserializeConnect(body);
case MqttControlPacketType.ConnAck: return DeserializeConnAck(body);
case MqttControlPacketType.Disconnect: return new MqttDisconnectPacket();
case MqttControlPacketType.Publish: return DeserializePublish(stream, header);
case MqttControlPacketType.PubAck: return DeserializePubAck(stream);
case MqttControlPacketType.PubRec: return DeserializePubRec(stream);
case MqttControlPacketType.PubRel: return DeserializePubRel(stream);
case MqttControlPacketType.PubComp: return DeserializePubComp(stream);
case MqttControlPacketType.Publish: return DeserializePublish(header, body);
case MqttControlPacketType.PubAck: return DeserializePubAck(body);
case MqttControlPacketType.PubRec: return DeserializePubRec(body);
case MqttControlPacketType.PubRel: return DeserializePubRel(body);
case MqttControlPacketType.PubComp: return DeserializePubComp(body);
case MqttControlPacketType.PingReq: return new MqttPingReqPacket();
case MqttControlPacketType.PingResp: return new MqttPingRespPacket();
case MqttControlPacketType.Subscribe: return DeserializeSubscribe(stream, header);
case MqttControlPacketType.SubAck: return DeserializeSubAck(stream, header);
case MqttControlPacketType.Unsubscibe: return DeserializeUnsubscribe(stream, header);
case MqttControlPacketType.UnsubAck: return DeserializeUnsubAck(stream);
case MqttControlPacketType.Subscribe: return DeserializeSubscribe(header, body);
case MqttControlPacketType.SubAck: return DeserializeSubAck(header, body);
case MqttControlPacketType.Unsubscibe: return DeserializeUnsubscribe(header, body);
case MqttControlPacketType.UnsubAck: return DeserializeUnsubAck(body);
default: throw new MqttProtocolViolationException($"Packet type ({(int)header.ControlPacketType}) not supported.");
}
}

private static MqttBasePacket DeserializeUnsubAck(Stream stream)
private static MqttBasePacket DeserializeUnsubAck(Stream body)
{
ThrowIfBodyIsEmpty(body);

return new MqttUnsubAckPacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};
}

private static MqttBasePacket DeserializePubComp(Stream stream)
private static MqttBasePacket DeserializePubComp(Stream body)
{
ThrowIfBodyIsEmpty(body);

return new MqttPubCompPacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};
}

private static MqttBasePacket DeserializePubRel(Stream stream)
private static MqttBasePacket DeserializePubRel(Stream body)
{
ThrowIfBodyIsEmpty(body);

return new MqttPubRelPacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};
}

private static MqttBasePacket DeserializePubRec(Stream stream)
private static MqttBasePacket DeserializePubRec(Stream body)
{
ThrowIfBodyIsEmpty(body);

return new MqttPubRecPacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};
}

private static MqttBasePacket DeserializePubAck(Stream stream)
private static MqttBasePacket DeserializePubAck(Stream body)
{
ThrowIfBodyIsEmpty(body);

return new MqttPubAckPacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};
}

private static MqttBasePacket DeserializeUnsubscribe(Stream stream, MqttPacketHeader header)
private static MqttBasePacket DeserializeUnsubscribe(MqttPacketHeader header, Stream body)
{
ThrowIfBodyIsEmpty(body);

var packet = new MqttUnsubscribePacket
{
PacketIdentifier = stream.ReadUInt16(),
PacketIdentifier = body.ReadUInt16(),
};

while (stream.Position != header.BodyLength)
while (body.Position != header.BodyLength)
{
packet.TopicFilters.Add(stream.ReadStringWithLengthPrefix());
packet.TopicFilters.Add(body.ReadStringWithLengthPrefix());
}

return packet;
}

private static MqttBasePacket DeserializeSubscribe(Stream stream, MqttPacketHeader header)
private static MqttBasePacket DeserializeSubscribe(MqttPacketHeader header, Stream body)
{
ThrowIfBodyIsEmpty(body);

var packet = new MqttSubscribePacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};

while (stream.Position != header.BodyLength)
while (body.Position != header.BodyLength)
{
packet.TopicFilters.Add(new TopicFilter(
stream.ReadStringWithLengthPrefix(),
(MqttQualityOfServiceLevel)stream.ReadByte()));
body.ReadStringWithLengthPrefix(),
(MqttQualityOfServiceLevel)body.ReadByte()));
}

return packet;
}

private static MqttBasePacket DeserializePublish(Stream stream, MqttPacketHeader mqttPacketHeader)
private static MqttBasePacket DeserializePublish(MqttPacketHeader header, Stream body)
{
var fixedHeader = new ByteReader(mqttPacketHeader.FixedHeader);
ThrowIfBodyIsEmpty(body);

var fixedHeader = new ByteReader(header.FixedHeader);
var retain = fixedHeader.Read();
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2);
var dup = fixedHeader.Read();

var topic = stream.ReadStringWithLengthPrefix();
var topic = body.ReadStringWithLengthPrefix();

ushort? packetIdentifier = null;
if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce)
{
packetIdentifier = stream.ReadUInt16();
packetIdentifier = body.ReadUInt16();
}

var packet = new MqttPublishPacket
@@ -184,7 +196,7 @@ namespace MQTTnet.Serializer
PacketIdentifier = packetIdentifier,
Retain = retain,
Topic = topic,
Payload = stream.ReadRemainingData(mqttPacketHeader),
Payload = body.ReadRemainingData(header),
QualityOfServiceLevel = qualityOfServiceLevel,
Dup = dup
};
@@ -192,57 +204,65 @@ namespace MQTTnet.Serializer
return packet;
}

private static MqttBasePacket DeserializeConnect(Stream stream)
private static MqttBasePacket DeserializeConnect(Stream body)
{
stream.ReadBytes(2); // Skip 2 bytes for header and remaining length.
ThrowIfBodyIsEmpty(body);

MqttProtocolVersion protocolVersion;
var protocolName = stream.ReadBytes(4);
var protocolName = body.ReadStringWithLengthPrefix();

if (protocolName.SequenceEqual(ProtocolVersionV311Name))
MqttProtocolVersion protocolVersion;
if (protocolName == "MQTT")
{
var protocolLevel = body.ReadByte();
if (protocolLevel != 4)
{
throw new MqttProtocolViolationException($"Protocol level ({protocolLevel}) not supported for MQTT 3.1.1.");
}

protocolVersion = MqttProtocolVersion.V311;
}
else
else if (protocolName == "MQIsdp")
{
protocolName = protocolName.Concat(stream.ReadBytes(2)).ToArray();
if (protocolName.SequenceEqual(ProtocolVersionV310Name))
var protocolLevel = body.ReadByte();
if (protocolLevel != 3)
{
protocolVersion = MqttProtocolVersion.V310;
throw new MqttProtocolViolationException($"Protocol level ({protocolLevel}) not supported for MQTT 3.1.");
}
else
{
throw new MqttProtocolViolationException("Protocol name is not supported.");
}
}

stream.ReadByte(); // Skip protocol level
var connectFlags = stream.ReadByte();
protocolVersion = MqttProtocolVersion.V310;
}
else
{
throw new MqttProtocolViolationException($"Protocol name ({protocolName}) is not supported.");
}

var connectFlagsReader = new ByteReader(connectFlags);
connectFlagsReader.Read(); // Reserved.
var connectFlags = new ByteReader(body.ReadByte());
if (connectFlags.Read())
{
throw new MqttProtocolViolationException("The first bit of the Connect Flags must be set to 0.");
}

var packet = new MqttConnectPacket
{
ProtocolVersion = protocolVersion,
CleanSession = connectFlagsReader.Read()
CleanSession = connectFlags.Read()
};

var willFlag = connectFlagsReader.Read();
var willQoS = connectFlagsReader.Read(2);
var willRetain = connectFlagsReader.Read();
var passwordFlag = connectFlagsReader.Read();
var usernameFlag = connectFlagsReader.Read();
var willFlag = connectFlags.Read();
var willQoS = connectFlags.Read(2);
var willRetain = connectFlags.Read();
var passwordFlag = connectFlags.Read();
var usernameFlag = connectFlags.Read();

packet.KeepAlivePeriod = stream.ReadUInt16();
packet.ClientId = stream.ReadStringWithLengthPrefix();
packet.KeepAlivePeriod = body.ReadUInt16();
packet.ClientId = body.ReadStringWithLengthPrefix();

if (willFlag)
{
packet.WillMessage = new MqttApplicationMessage
{
Topic = stream.ReadStringWithLengthPrefix(),
Payload = stream.ReadWithLengthPrefix(),
Topic = body.ReadStringWithLengthPrefix(),
Payload = body.ReadWithLengthPrefix(),
QualityOfServiceLevel = (MqttQualityOfServiceLevel)willQoS,
Retain = willRetain
};
@@ -250,45 +270,49 @@ namespace MQTTnet.Serializer

if (usernameFlag)
{
packet.Username = stream.ReadStringWithLengthPrefix();
packet.Username = body.ReadStringWithLengthPrefix();
}

if (passwordFlag)
{
packet.Password = stream.ReadStringWithLengthPrefix();
packet.Password = body.ReadStringWithLengthPrefix();
}

ValidateConnectPacket(packet);
return packet;
}

private static MqttBasePacket DeserializeSubAck(Stream stream, MqttPacketHeader header)
private static MqttBasePacket DeserializeSubAck(MqttPacketHeader header, Stream body)
{
ThrowIfBodyIsEmpty(body);

var packet = new MqttSubAckPacket
{
PacketIdentifier = stream.ReadUInt16()
PacketIdentifier = body.ReadUInt16()
};

while (stream.Position != header.BodyLength)
while (body.Position != header.BodyLength)
{
packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)stream.ReadByte());
packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)body.ReadByte());
}

return packet;
}

private MqttBasePacket DeserializeConnAck(Stream stream)
private MqttBasePacket DeserializeConnAck(Stream body)
{
ThrowIfBodyIsEmpty(body);

var packet = new MqttConnAckPacket();

var firstByteReader = new ByteReader(stream.ReadByte());
var firstByteReader = new ByteReader(body.ReadByte());

if (ProtocolVersion == MqttProtocolVersion.V311)
{
packet.IsSessionPresent = firstByteReader.Read();
}

packet.ConnectReturnCode = (MqttConnectReturnCode)stream.ReadByte();
packet.ConnectReturnCode = (MqttConnectReturnCode)body.ReadByte();

return packet;
}
@@ -320,12 +344,12 @@ namespace MQTTnet.Serializer
// Write variable header
if (ProtocolVersion == MqttProtocolVersion.V311)
{
stream.WriteWithLengthPrefix(ProtocolVersionV311Name);
stream.WriteWithLengthPrefix("MQTT");
stream.WriteByte(0x04); // 3.1.2.2 Protocol Level 4
}
else
{
stream.WriteWithLengthPrefix(ProtocolVersionV310Name);
stream.WriteWithLengthPrefix("MQIsdp");
stream.WriteByte(0x03); // Protocol Level 3
}

@@ -571,5 +595,13 @@ namespace MQTTnet.Serializer
{
return MqttPacketWriter.BuildFixedHeader(type);
}

private static void ThrowIfBodyIsEmpty(Stream body)
{
if (body == null || body.Length == 0)
{
throw new MqttProtocolViolationException("Data from the body is required but not present.");
}
}
}
}

Loading…
Cancel
Save