diff --git a/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs b/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs index b1d1a38..79eb2c4 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/MqttApplicationMessageExtensions.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Internal { - internal static class MqttApplicationMessageExtensions + public static class MqttApplicationMessageExtensions { public static MqttApplicationMessage ToApplicationMessage(this MqttPublishPacket publishPacket) { diff --git a/Tests/MQTTnet.Core.Tests/TestMqttChannel.cs b/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs similarity index 96% rename from Tests/MQTTnet.Core.Tests/TestMqttChannel.cs rename to Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs index 2b4914b..b380b08 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs @@ -3,7 +3,7 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; -namespace MQTTnet.Core.Tests +namespace MQTTnet.Core.Internal { public class TestMqttChannel : IMqttChannel { diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs index 4c43579..4123924 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs @@ -10,18 +10,8 @@ using MQTTnet.Protocol; namespace MQTTnet.Serializer { - public sealed class MqttPacketReader : BinaryReader + public static class MqttPacketReader { - private readonly MqttPacketHeader _header; - - public MqttPacketReader(MqttPacketHeader header, Stream bodyStream) - : base(bodyStream, Encoding.UTF8, true) - { - _header = header; - } - - public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; - public static async Task ReadHeaderAsync(IMqttChannel stream, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) @@ -51,9 +41,9 @@ namespace MQTTnet.Serializer }; } - public override ushort ReadUInt16() + public static ushort ReadUInt16(this Stream stream) { - var buffer = ReadBytes(2); + var buffer = stream.ReadBytes(2); var temp = buffer[0]; buffer[0] = buffer[1]; @@ -62,9 +52,9 @@ namespace MQTTnet.Serializer return BitConverter.ToUInt16(buffer, 0); } - public string ReadStringWithLengthPrefix() + public static string ReadStringWithLengthPrefix(this Stream stream) { - var buffer = ReadWithLengthPrefix(); + var buffer = stream.ReadWithLengthPrefix(); if (buffer.Length == 0) { return string.Empty; @@ -73,20 +63,27 @@ namespace MQTTnet.Serializer return Encoding.UTF8.GetString(buffer, 0, buffer.Length); } - public byte[] ReadWithLengthPrefix() + public static byte[] ReadWithLengthPrefix(this Stream stream) { - var length = ReadUInt16(); + var length = stream.ReadUInt16(); if (length == 0) { return new byte[0]; } - return ReadBytes(length); + return stream.ReadBytes(length); + } + + public static byte[] ReadRemainingData(this Stream stream, MqttPacketHeader header) + { + return stream.ReadBytes(header.BodyLength - (int)stream.Position); } - public byte[] ReadRemainingData() + public static byte[] ReadBytes(this Stream stream, int count) { - return ReadBytes(_header.BodyLength - (int)BaseStream.Position); + var buffer = new byte[count]; + stream.Read(buffer, 0, count); + return buffer; } private static async Task ReadBodyLengthAsync(IMqttChannel stream, CancellationToken cancellationToken) diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 18818ce..12330b5 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -20,11 +20,10 @@ namespace MQTTnet.Serializer if (packet == null) throw new ArgumentNullException(nameof(packet)); using (var stream = new MemoryStream(128)) - using (var writer = new MqttPacketWriter(stream)) { // Leave enough head space for max header size (fixed + 4 variable remaining length) stream.Position = 5; - var fixedHeader = SerializePacket(packet, writer); + var fixedHeader = SerializePacket(packet, stream); stream.Position = 1; var remainingLength = MqttPacketWriter.EncodeRemainingLength((int)stream.Length - 5, stream); @@ -35,7 +34,7 @@ namespace MQTTnet.Serializer // Position cursor on correct offset on beginining of array (has leading 0x0) stream.Position = headerOffset; - writer.Write(fixedHeader); + stream.WriteByte(fixedHeader); #if NET461 || NET452 || NETSTANDARD2_0 var buffer = stream.GetBuffer(); @@ -46,146 +45,138 @@ namespace MQTTnet.Serializer } } - public MqttBasePacket Deserialize(MqttPacketHeader header, Stream body) - { - if (header == null) throw new ArgumentNullException(nameof(header)); - if (body == null) throw new ArgumentNullException(nameof(body)); - - using (var reader = new MqttPacketReader(header, body)) - { - return Deserialize(header, reader); - } - } - - private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter writer) + private byte SerializePacket(MqttBasePacket packet, Stream stream) { switch (packet) { - case MqttConnectPacket connectPacket: return Serialize(connectPacket, writer); - case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, writer); + case MqttConnectPacket connectPacket: return Serialize(connectPacket, stream); + case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, stream); case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect); case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq); case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp); - case MqttPublishPacket publishPacket: return Serialize(publishPacket, writer); - case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, writer); - case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, writer); - case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, writer); - case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, writer); - case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, writer); - case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, writer); - case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, writer); - case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, writer); + case MqttPublishPacket publishPacket: return Serialize(publishPacket, stream); + case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, stream); + case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, stream); + case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, stream); + case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, stream); + case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, stream); + case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, stream); + case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, stream); + case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, stream); default: throw new MqttProtocolViolationException("Packet type invalid."); } } - private MqttBasePacket Deserialize(MqttPacketHeader header, MqttPacketReader reader) + public MqttBasePacket Deserialize(MqttPacketHeader header, Stream stream) { + if (header == null) throw new ArgumentNullException(nameof(header)); + if (stream == null) throw new ArgumentNullException(nameof(stream)); + switch (header.ControlPacketType) { - case MqttControlPacketType.Connect: return DeserializeConnect(reader); - case MqttControlPacketType.ConnAck: return DeserializeConnAck(reader); + case MqttControlPacketType.Connect: return DeserializeConnect(stream); + case MqttControlPacketType.ConnAck: return DeserializeConnAck(stream); case MqttControlPacketType.Disconnect: return new MqttDisconnectPacket(); - case MqttControlPacketType.Publish: return DeserializePublish(reader, header); - case MqttControlPacketType.PubAck: return DeserializePubAck(reader); - case MqttControlPacketType.PubRec: return DeserializePubRec(reader); - case MqttControlPacketType.PubRel: return DeserializePubRel(reader); - case MqttControlPacketType.PubComp: return DeserializePubComp(reader); + case MqttControlPacketType.Publish: return DeserializePublish(stream, header); + case MqttControlPacketType.PubAck: return DeserializePubAck(stream); + case MqttControlPacketType.PubRec: return DeserializePubRec(stream); + case MqttControlPacketType.PubRel: return DeserializePubRel(stream); + case MqttControlPacketType.PubComp: return DeserializePubComp(stream); case MqttControlPacketType.PingReq: return new MqttPingReqPacket(); case MqttControlPacketType.PingResp: return new MqttPingRespPacket(); - case MqttControlPacketType.Subscribe: return DeserializeSubscribe(reader); - case MqttControlPacketType.SubAck: return DeserializeSubAck(reader); - case MqttControlPacketType.Unsubscibe: return DeserializeUnsubscribe(reader); - case MqttControlPacketType.UnsubAck: return DeserializeUnsubAck(reader); + case MqttControlPacketType.Subscribe: return DeserializeSubscribe(stream, header); + case MqttControlPacketType.SubAck: return DeserializeSubAck(stream, header); + case MqttControlPacketType.Unsubscibe: return DeserializeUnsubscribe(stream, header); + case MqttControlPacketType.UnsubAck: return DeserializeUnsubAck(stream); default: throw new MqttProtocolViolationException($"Packet type ({(int)header.ControlPacketType}) not supported."); } } - private static MqttBasePacket DeserializeUnsubAck(MqttPacketReader reader) + private static MqttBasePacket DeserializeUnsubAck(Stream stream) { return new MqttUnsubAckPacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; } - private static MqttBasePacket DeserializePubComp(MqttPacketReader reader) + private static MqttBasePacket DeserializePubComp(Stream stream) { return new MqttPubCompPacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; } - private static MqttBasePacket DeserializePubRel(MqttPacketReader reader) + private static MqttBasePacket DeserializePubRel(Stream stream) { return new MqttPubRelPacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; } - private static MqttBasePacket DeserializePubRec(MqttPacketReader reader) + private static MqttBasePacket DeserializePubRec(Stream stream) { return new MqttPubRecPacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; } - private static MqttBasePacket DeserializePubAck(MqttPacketReader reader) + private static MqttBasePacket DeserializePubAck(Stream stream) { return new MqttPubAckPacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; } - private static MqttBasePacket DeserializeUnsubscribe(MqttPacketReader reader) + private static MqttBasePacket DeserializeUnsubscribe(Stream stream, MqttPacketHeader header) { var packet = new MqttUnsubscribePacket { - PacketIdentifier = reader.ReadUInt16(), + PacketIdentifier = stream.ReadUInt16(), }; - while (!reader.EndOfRemainingData) + while (stream.Position != header.BodyLength) { - packet.TopicFilters.Add(reader.ReadStringWithLengthPrefix()); + packet.TopicFilters.Add(stream.ReadStringWithLengthPrefix()); } return packet; } - private static MqttBasePacket DeserializeSubscribe(MqttPacketReader reader) + private static MqttBasePacket DeserializeSubscribe(Stream stream, MqttPacketHeader header) { var packet = new MqttSubscribePacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; - while (!reader.EndOfRemainingData) + while (stream.Position != header.BodyLength) { packet.TopicFilters.Add(new TopicFilter( - reader.ReadStringWithLengthPrefix(), - (MqttQualityOfServiceLevel)reader.ReadByte())); + stream.ReadStringWithLengthPrefix(), + (MqttQualityOfServiceLevel)stream.ReadByte())); } return packet; } - private static MqttBasePacket DeserializePublish(MqttPacketReader reader, MqttPacketHeader mqttPacketHeader) + private static MqttBasePacket DeserializePublish(Stream stream, MqttPacketHeader mqttPacketHeader) { var fixedHeader = new ByteReader(mqttPacketHeader.FixedHeader); var retain = fixedHeader.Read(); var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2); var dup = fixedHeader.Read(); - var topic = reader.ReadStringWithLengthPrefix(); + var topic = stream.ReadStringWithLengthPrefix(); ushort? packetIdentifier = null; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - packetIdentifier = reader.ReadUInt16(); + packetIdentifier = stream.ReadUInt16(); } var packet = new MqttPublishPacket @@ -193,7 +184,7 @@ namespace MQTTnet.Serializer PacketIdentifier = packetIdentifier, Retain = retain, Topic = topic, - Payload = reader.ReadRemainingData(), + Payload = stream.ReadRemainingData(mqttPacketHeader), QualityOfServiceLevel = qualityOfServiceLevel, Dup = dup }; @@ -201,12 +192,12 @@ namespace MQTTnet.Serializer return packet; } - private static MqttBasePacket DeserializeConnect(MqttPacketReader reader) + private static MqttBasePacket DeserializeConnect(Stream stream) { - reader.ReadBytes(2); // Skip 2 bytes for header and remaining length. + stream.ReadBytes(2); // Skip 2 bytes for header and remaining length. MqttProtocolVersion protocolVersion; - var protocolName = reader.ReadBytes(4); + var protocolName = stream.ReadBytes(4); if (protocolName.SequenceEqual(ProtocolVersionV311Name)) { @@ -216,7 +207,7 @@ namespace MQTTnet.Serializer { var buffer = new byte[6]; Array.Copy(protocolName, buffer, 4); - protocolName = reader.ReadBytes(2); + protocolName = stream.ReadBytes(2); Array.Copy(protocolName, 0, buffer, 4, 2); if (protocolName.SequenceEqual(ProtocolVersionV310Name)) @@ -229,8 +220,8 @@ namespace MQTTnet.Serializer } } - reader.ReadByte(); // Skip protocol level - var connectFlags = reader.ReadByte(); + stream.ReadByte(); // Skip protocol level + var connectFlags = stream.ReadByte(); var connectFlagsReader = new ByteReader(connectFlags); connectFlagsReader.Read(); // Reserved. @@ -247,15 +238,15 @@ namespace MQTTnet.Serializer var passwordFlag = connectFlagsReader.Read(); var usernameFlag = connectFlagsReader.Read(); - packet.KeepAlivePeriod = reader.ReadUInt16(); - packet.ClientId = reader.ReadStringWithLengthPrefix(); + packet.KeepAlivePeriod = stream.ReadUInt16(); + packet.ClientId = stream.ReadStringWithLengthPrefix(); if (willFlag) { packet.WillMessage = new MqttApplicationMessage { - Topic = reader.ReadStringWithLengthPrefix(), - Payload = reader.ReadWithLengthPrefix(), + Topic = stream.ReadStringWithLengthPrefix(), + Payload = stream.ReadWithLengthPrefix(), QualityOfServiceLevel = (MqttQualityOfServiceLevel)willQoS, Retain = willRetain }; @@ -263,45 +254,45 @@ namespace MQTTnet.Serializer if (usernameFlag) { - packet.Username = reader.ReadStringWithLengthPrefix(); + packet.Username = stream.ReadStringWithLengthPrefix(); } if (passwordFlag) { - packet.Password = reader.ReadStringWithLengthPrefix(); + packet.Password = stream.ReadStringWithLengthPrefix(); } ValidateConnectPacket(packet); return packet; } - private static MqttBasePacket DeserializeSubAck(MqttPacketReader reader) + private static MqttBasePacket DeserializeSubAck(Stream stream, MqttPacketHeader header) { var packet = new MqttSubAckPacket { - PacketIdentifier = reader.ReadUInt16() + PacketIdentifier = stream.ReadUInt16() }; - while (!reader.EndOfRemainingData) + while (stream.Position != header.BodyLength) { - packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)reader.ReadByte()); + packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)stream.ReadByte()); } return packet; } - private MqttBasePacket DeserializeConnAck(MqttPacketReader reader) + private MqttBasePacket DeserializeConnAck(Stream stream) { var packet = new MqttConnAckPacket(); - var firstByteReader = new ByteReader(reader.ReadByte()); + var firstByteReader = new ByteReader(stream.ReadByte()); if (ProtocolVersion == MqttProtocolVersion.V311) { packet.IsSessionPresent = firstByteReader.Read(); } - packet.ConnectReturnCode = (MqttConnectReturnCode)reader.ReadByte(); + packet.ConnectReturnCode = (MqttConnectReturnCode)stream.ReadByte(); return packet; } @@ -326,20 +317,20 @@ namespace MQTTnet.Serializer } } - private byte Serialize(MqttConnectPacket packet, MqttPacketWriter writer) + private byte Serialize(MqttConnectPacket packet, Stream stream) { ValidateConnectPacket(packet); // Write variable header if (ProtocolVersion == MqttProtocolVersion.V311) { - writer.WriteWithLengthPrefix(ProtocolVersionV311Name); - writer.Write(0x04); // 3.1.2.2 Protocol Level 4 + stream.WriteWithLengthPrefix(ProtocolVersionV311Name); + stream.WriteByte(0x04); // 3.1.2.2 Protocol Level 4 } else { - writer.WriteWithLengthPrefix(ProtocolVersionV310Name); - writer.Write(0x03); // Protocol Level 3 + stream.WriteWithLengthPrefix(ProtocolVersionV310Name); + stream.WriteByte(0x03); // Protocol Level 3 } var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags @@ -366,68 +357,68 @@ namespace MQTTnet.Serializer connectFlags.Write(packet.Password != null); connectFlags.Write(packet.Username != null); - writer.Write(connectFlags); - writer.Write(packet.KeepAlivePeriod); - writer.WriteWithLengthPrefix(packet.ClientId); + stream.Write(connectFlags); + stream.Write(packet.KeepAlivePeriod); + stream.WriteWithLengthPrefix(packet.ClientId); if (packet.WillMessage != null) { - writer.WriteWithLengthPrefix(packet.WillMessage.Topic); - writer.WriteWithLengthPrefix(packet.WillMessage.Payload); + stream.WriteWithLengthPrefix(packet.WillMessage.Topic); + stream.WriteWithLengthPrefix(packet.WillMessage.Payload); } if (packet.Username != null) { - writer.WriteWithLengthPrefix(packet.Username); + stream.WriteWithLengthPrefix(packet.Username); } if (packet.Password != null) { - writer.WriteWithLengthPrefix(packet.Password); + stream.WriteWithLengthPrefix(packet.Password); } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - private byte Serialize(MqttConnAckPacket packet, MqttPacketWriter writer) + private byte Serialize(MqttConnAckPacket packet, Stream stream) { if (ProtocolVersion == MqttProtocolVersion.V310) { - writer.Write(0); + stream.WriteByte(0); } else if (ProtocolVersion == MqttProtocolVersion.V311) { var connectAcknowledgeFlags = new ByteWriter(); connectAcknowledgeFlags.Write(packet.IsSessionPresent); - writer.Write(connectAcknowledgeFlags); + stream.Write(connectAcknowledgeFlags); } else { throw new MqttProtocolViolationException("Protocol version not supported."); } - writer.Write((byte)packet.ConnectReturnCode); + stream.WriteByte((byte)packet.ConnectReturnCode); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttPubRelPacket packet, Stream stream) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubRel packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static byte Serialize(MqttPublishPacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttPublishPacket packet, Stream stream) { ValidatePublishPacket(packet); - writer.WriteWithLengthPrefix(packet.Topic); + stream.WriteWithLengthPrefix(packet.Topic); if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { @@ -436,7 +427,7 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Publish packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); } else { @@ -448,7 +439,7 @@ namespace MQTTnet.Serializer if (packet.Payload?.Length > 0) { - writer.Write(packet.Payload); + stream.Write(packet.Payload, 0, packet.Payload.Length); } byte fixedHeader = 0; @@ -468,43 +459,43 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } - private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttPubAckPacket packet, Stream stream) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttPubRecPacket packet, Stream stream) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubRec packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttPubCompPacket packet, Stream stream) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("PubComp packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static byte Serialize(MqttSubscribePacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttSubscribePacket packet, Stream stream) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); @@ -513,41 +504,41 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Subscribe packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Count > 0) { foreach (var topicFilter in packet.TopicFilters) { - writer.WriteWithLengthPrefix(topicFilter.Topic); - writer.Write((byte)topicFilter.QualityOfServiceLevel); + stream.WriteWithLengthPrefix(topicFilter.Topic); + stream.WriteByte((byte)topicFilter.QualityOfServiceLevel); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttSubAckPacket packet, Stream stream) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("SubAck packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); if (packet.SubscribeReturnCodes?.Any() == true) { foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes) { - writer.Write((byte)packetSubscribeReturnCode); + stream.WriteByte((byte)packetSubscribeReturnCode); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static byte Serialize(MqttUnsubscribePacket packet, MqttPacketWriter writer) + private static byte Serialize(MqttUnsubscribePacket packet, Stream stream) { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); @@ -556,27 +547,27 @@ namespace MQTTnet.Serializer throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Any() == true) { foreach (var topicFilter in packet.TopicFilters) { - writer.WriteWithLengthPrefix(topicFilter); + stream.WriteWithLengthPrefix(topicFilter); } } return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte Serialize(MqttUnsubAckPacket packet, BinaryWriter writer) + private static byte Serialize(MqttUnsubAckPacket packet, Stream stream) { if (!packet.PacketIdentifier.HasValue) { throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier."); } - writer.Write(packet.PacketIdentifier.Value); + stream.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs index 54f40eb..1d04018 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs @@ -5,13 +5,8 @@ using MQTTnet.Protocol; namespace MQTTnet.Serializer { - public sealed class MqttPacketWriter : BinaryWriter + public static class MqttPacketWriter { - public MqttPacketWriter(Stream stream) - : base(stream, Encoding.UTF8, true) - { - } - public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) { var fixedHeader = (int)packetType << 4; @@ -19,40 +14,31 @@ namespace MQTTnet.Serializer return (byte)fixedHeader; } - public override void Write(ushort value) + public static void Write(this Stream stream, ushort value) { var buffer = BitConverter.GetBytes(value); - Write(buffer[1], buffer[0]); - } - - public new void Write(params byte[] values) - { - base.Write(values); - } - - public new void Write(byte value) - { - base.Write(value); + stream.WriteByte(buffer[1]); + stream.WriteByte(buffer[0]); } - public void Write(ByteWriter value) + public static void Write(this Stream stream, ByteWriter value) { if (value == null) throw new ArgumentNullException(nameof(value)); - Write(value.Value); + stream.WriteByte(value.Value); } - public void WriteWithLengthPrefix(string value) + public static void WriteWithLengthPrefix(this Stream stream, string value) { - WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); + stream.WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); } - public void WriteWithLengthPrefix(byte[] value) + public static void WriteWithLengthPrefix(this Stream stream, byte[] value) { var length = (ushort)value.Length; - Write(length); - Write(value); + stream.Write(length); + stream.Write(value, 0, length); } public static int EncodeRemainingLength(int length, MemoryStream stream) diff --git a/Frameworks/MQTTnet.NetStandard/Server/IMqttServerOptions.cs b/Frameworks/MQTTnet.NetStandard/Server/IMqttServerOptions.cs index cc58eeb..9ae7715 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/IMqttServerOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/IMqttServerOptions.cs @@ -5,6 +5,8 @@ namespace MQTTnet.Server public interface IMqttServerOptions { int ConnectionBacklog { get; } + int MaxPendingMessagesPerClient { get; } + TimeSpan DefaultCommunicationTimeout { get; } Action ConnectionValidator { get; } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs index 1d00528..fd89691 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs @@ -50,6 +50,17 @@ namespace MQTTnet.Server } } + public async Task DropPacket() + { + MqttBasePacket packet = null; + await _queueWaitSemaphore.WaitAsync().ConfigureAwait(false); + if (!_queue.TryDequeue(out packet)) + { + throw new InvalidOperationException(); // should not happen + } + _queueWaitSemaphore.Release(); + } + public void Enqueue(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs index 183ab8e..7d50bad 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs @@ -147,7 +147,10 @@ namespace MQTTnet.Server { publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); } - + if (_options.MaxPendingMessagesPerClient <= PendingMessagesQueue.Count) + { + await PendingMessagesQueue.DropPacket(); + } PendingMessagesQueue.Enqueue(publishPacket); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptions.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptions.cs index 2b0a9a0..4315fcf 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptions.cs @@ -9,6 +9,8 @@ namespace MQTTnet.Server public MqttServerTlsEndpointOptions TlsEndpointOptions { get; } = new MqttServerTlsEndpointOptions(); public int ConnectionBacklog { get; set; } = 10; + + public int MaxPendingMessagesPerClient { get; set; } = 250; public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15); diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptionsBuilder.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptionsBuilder.cs index 18b2de2..2c86512 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptionsBuilder.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttServerOptionsBuilder.cs @@ -13,6 +13,12 @@ namespace MQTTnet.Server return this; } + public MqttServerOptionsBuilder WithMaxPendingMessagesPerClient(int value) + { + _options.MaxPendingMessagesPerClient = value; + return this; + } + public MqttServerOptionsBuilder WithDefaultCommunicationTimeout(TimeSpan value) { _options.DefaultCommunicationTimeout = value; diff --git a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index 4977311..3b52bba 100644 --- a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -150,6 +150,7 @@ + diff --git a/Tests/MQTTnet.Benchmarks/Program.cs b/Tests/MQTTnet.Benchmarks/Program.cs index 965ddb5..ad2c363 100644 --- a/Tests/MQTTnet.Benchmarks/Program.cs +++ b/Tests/MQTTnet.Benchmarks/Program.cs @@ -1,5 +1,7 @@ using System; +using System.Threading; using BenchmarkDotNet.Running; +using MQTTnet.Diagnostics; namespace MQTTnet.Benchmarks { @@ -7,7 +9,23 @@ namespace MQTTnet.Benchmarks { public static void Main(string[] args) { - var summary = BenchmarkRunner.Run(); + Console.WriteLine($"MQTTnet - BenchmarkApp.{TargetFrameworkInfoProvider.TargetFramework}"); + Console.WriteLine("1 = MessageProcessingBenchmark"); + Console.WriteLine("2 = SerializerBenchmark"); + + var pressedKey = Console.ReadKey(true); + switch (pressedKey.KeyChar) + { + case '1': + BenchmarkRunner.Run(); + break; + case '2': + BenchmarkRunner.Run(); + break; + default: + break; + } + Console.ReadLine(); } } diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs new file mode 100644 index 0000000..bcfff1c --- /dev/null +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -0,0 +1,74 @@ +using BenchmarkDotNet.Attributes; +using MQTTnet.Client; +using MQTTnet.Packets; +using MQTTnet.Serializer; +using MQTTnet.Internal; +using MQTTnet.Server; +using BenchmarkDotNet.Attributes.Jobs; +using BenchmarkDotNet.Attributes.Exporters; +using System; +using System.Threading; +using System.IO; +using MQTTnet.Core.Internal; + +namespace MQTTnet.Benchmarks +{ + [ClrJob] + [RPlotExporter] + [MemoryDiagnoser] + public class SerializerBenchmark + { + private MqttBasePacket _packet; + private ArraySegment _serializedPacket; + private MqttPacketSerializer _serializer; + + [GlobalSetup] + public void Setup() + { + var message = new MqttApplicationMessageBuilder() + .WithTopic("A") + .Build(); + + _packet = message.ToPublishPacket(); + _serializer = new MqttPacketSerializer(); + _serializedPacket = _serializer.Serialize(_packet); + } + + [Benchmark] + public void Serialize_10000_Messages() + { + for (var i = 0; i < 10000; i++) + { + _serializer.Serialize(_packet); + } + } + + [Benchmark] + public void Deserialize_10000_Messages() + { + for (var i = 0; i < 10000; i++) + { + using (var headerStream = new MemoryStream(Join(_serializedPacket))) + { + var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(headerStream), CancellationToken.None).GetAwaiter().GetResult(); + + using (var bodyStream = new MemoryStream(Join(_serializedPacket), (int)headerStream.Position, header.BodyLength)) + { + _serializer.Deserialize(header, bodyStream); + } + } + } + } + + private static byte[] Join(params ArraySegment[] chunks) + { + var buffer = new MemoryStream(); + foreach (var chunk in chunks) + { + buffer.Write(chunk.Array, chunk.Offset, chunk.Count); + } + + return buffer.ToArray(); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index daeea62..696cfa3 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests [TestMethod] public async Task TimeoutAfter() { - await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] public async Task TimeoutAfterWithResult() { - await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [TestMethod] public async Task TimeoutAfterCompleteInTime() { - var result = await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); + var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); Assert.AreEqual(5, result); } @@ -36,7 +36,7 @@ namespace MQTTnet.Core.Tests { try { - await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -55,7 +55,7 @@ namespace MQTTnet.Core.Tests { try { - await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => + await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; @@ -76,7 +76,7 @@ namespace MQTTnet.Core.Tests var tasks = Enumerable.Range(0, 100000) .Select(i => { - return Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); + return MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); }); await Task.WhenAll(tasks); diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index ca48ad9..72a675b 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -1,6 +1,7 @@ using System.IO; using System.Threading; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Internal; using MQTTnet.Serializer; namespace MQTTnet.Core.Tests diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 580bc4d..a45736d 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -3,6 +3,7 @@ using System.IO; using System.Text; using System.Threading; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Serializer;