From 7a9ddfc2b5287fb840fc08dc0ccaa099c00189d3 Mon Sep 17 00:00:00 2001 From: Christian Date: Sun, 25 Mar 2018 14:41:12 +0200 Subject: [PATCH] Fix packet identifier which is sent from the server. --- .../Adapter/MqttChannelAdapter.cs | 8 +-- .../Adapter/MqttChannelAdapterExtensions.cs | 17 ----- .../Packets/IMqttPacketWithIdentifier.cs | 2 +- .../Packets/MqttBasePublishPacket.cs | 2 +- .../Packets/MqttPacketExtensions.cs | 26 -------- .../Packets/MqttSubAckPacket.cs | 2 +- .../Packets/MqttSubscribePacket.cs | 2 +- .../Packets/MqttUnsubAckPacket.cs | 2 +- .../Packets/MqttUnsubscribe.cs | 2 +- .../Serializer/MqttPacketSerializer.cs | 65 ++++++++++++++++--- .../Server/MqttClientPendingMessagesQueue.cs | 2 +- .../Server/MqttClientSession.cs | 28 +++++--- .../Server/MqttClientSessionsManager.cs | 8 +-- .../Server/MqttClientSubscriptionsManager.cs | 13 +++- 14 files changed, 100 insertions(+), 79 deletions(-) delete mode 100644 Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapterExtensions.cs delete mode 100644 Frameworks/MQTTnet.NetStandard/Packets/MqttPacketExtensions.cs diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index b94cf61..0ad382a 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -138,7 +138,7 @@ namespace MQTTnet.Adapter } var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(header.BodyLength) : new MemoryStream(); - + var buffer = new byte[ReadBufferSize]; while (body.Length < header.BodyLength) { @@ -149,7 +149,7 @@ namespace MQTTnet.Adapter } var readBytesCount = await stream.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); - + // Check if the client closed the connection before sending the full body. if (readBytesCount == 0) { @@ -162,7 +162,7 @@ namespace MQTTnet.Adapter } body.Seek(0L, SeekOrigin.Begin); - + return new ReceivedMqttPacket(header, body); } @@ -190,7 +190,7 @@ namespace MQTTnet.Adapter } catch (COMException comException) { - if ((uint) comException.HResult == ErrorOperationAborted) + if ((uint)comException.HResult == ErrorOperationAborted) { throw new OperationCanceledException(); } diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapterExtensions.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapterExtensions.cs deleted file mode 100644 index 58258ab..0000000 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapterExtensions.cs +++ /dev/null @@ -1,17 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Packets; - -namespace MQTTnet.Adapter -{ - public static class MqttChannelAdapterExtensions - { - public static Task SendPacketsAsync(this IMqttChannelAdapter adapter, TimeSpan timeout, CancellationToken cancellationToken, params MqttBasePacket[] packets) - { - if (adapter == null) throw new ArgumentNullException(nameof(adapter)); - - return adapter.SendPacketsAsync(timeout, cancellationToken, packets); - } - } -} \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Packets/IMqttPacketWithIdentifier.cs b/Frameworks/MQTTnet.NetStandard/Packets/IMqttPacketWithIdentifier.cs index 5f7f8e9..6bbce0c 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/IMqttPacketWithIdentifier.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/IMqttPacketWithIdentifier.cs @@ -2,6 +2,6 @@ { public interface IMqttPacketWithIdentifier { - ushort PacketIdentifier { get; set; } + ushort? PacketIdentifier { get; set; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttBasePublishPacket.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttBasePublishPacket.cs index 6218c15..ffafc53 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttBasePublishPacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/MqttBasePublishPacket.cs @@ -2,6 +2,6 @@ { public class MqttBasePublishPacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort PacketIdentifier { get; set; } + public ushort? PacketIdentifier { get; set; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttPacketExtensions.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttPacketExtensions.cs deleted file mode 100644 index a3df53c..0000000 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttPacketExtensions.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System; - -namespace MQTTnet.Packets -{ - public static class MqttPacketExtensions - { - public static TResponsePacket CreateResponse(this MqttBasePacket packet) - { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - - var responsePacket = Activator.CreateInstance(); - - if (responsePacket is IMqttPacketWithIdentifier responsePacketWithIdentifier) - { - if (!(packet is IMqttPacketWithIdentifier requestPacketWithIdentifier)) - { - throw new InvalidOperationException("Response packet has PacketIdentifier but request packet does not."); - } - - responsePacketWithIdentifier.PacketIdentifier = requestPacketWithIdentifier.PacketIdentifier; - } - - return responsePacket; - } - } -} diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttSubAckPacket.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttSubAckPacket.cs index c4c1a88..3b8cf7a 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttSubAckPacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/MqttSubAckPacket.cs @@ -6,7 +6,7 @@ namespace MQTTnet.Packets { public sealed class MqttSubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort PacketIdentifier { get; set; } + public ushort? PacketIdentifier { get; set; } public IList SubscribeReturnCodes { get; } = new List(); diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttSubscribePacket.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttSubscribePacket.cs index 25536e3..63aede1 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttSubscribePacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/MqttSubscribePacket.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Packets { public sealed class MqttSubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort PacketIdentifier { get; set; } + public ushort? PacketIdentifier { get; set; } public IList TopicFilters { get; set; } = new List(); diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubAckPacket.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubAckPacket.cs index 4797d0a..688195f 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubAckPacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubAckPacket.cs @@ -2,7 +2,7 @@ { public sealed class MqttUnsubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort PacketIdentifier { get; set; } + public ushort? PacketIdentifier { get; set; } public override string ToString() { diff --git a/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs b/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs index 2c47cbb..6c24a47 100644 --- a/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs +++ b/Frameworks/MQTTnet.NetStandard/Packets/MqttUnsubscribe.cs @@ -4,7 +4,7 @@ namespace MQTTnet.Packets { public sealed class MqttUnsubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort PacketIdentifier { get; set; } + public ushort? PacketIdentifier { get; set; } public IList TopicFilters { get; set; } = new List(); diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 63a3c64..0904c75 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -397,7 +397,12 @@ namespace MQTTnet.Serializer private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter writer) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("PubRel packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } @@ -410,7 +415,12 @@ namespace MQTTnet.Serializer if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("Publish packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); } else { @@ -444,21 +454,36 @@ namespace MQTTnet.Serializer private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter writer) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter writer) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("PubRec packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter writer) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("PubComp packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } @@ -467,7 +492,12 @@ namespace MQTTnet.Serializer { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("Subscribe packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Count > 0) { @@ -483,7 +513,12 @@ namespace MQTTnet.Serializer private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter writer) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("SubAck packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); if (packet.SubscribeReturnCodes?.Any() == true) { @@ -500,7 +535,12 @@ namespace MQTTnet.Serializer { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); if (packet.TopicFilters?.Any() == true) { @@ -513,9 +553,14 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte Serialize(IMqttPacketWithIdentifier packet, BinaryWriter writer) + private static byte Serialize(MqttUnsubAckPacket packet, BinaryWriter writer) { - writer.Write(packet.PacketIdentifier); + if (!packet.PacketIdentifier.HasValue) + { + throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier."); + } + + writer.Write(packet.PacketIdentifier.Value); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs index 27dd49d..2469ede 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs @@ -78,7 +78,7 @@ namespace MQTTnet.Server throw new InvalidOperationException(); // should not happen } - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, packet).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { packet }).ConfigureAwait(false); _logger.Trace("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs index aee4351..62bb39f 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; +using MQTTnet.Client; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Internal; @@ -14,6 +15,7 @@ namespace MQTTnet.Server { public sealed class MqttClientSession : IDisposable { + private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly IMqttServerOptions _options; private readonly IMqttNetLogger _logger; private readonly MqttRetainedMessagesManager _retainedMessagesManager; @@ -129,6 +131,11 @@ namespace MQTTnet.Server var publishPacket = applicationMessage.ToPublishPacket(); publishPacket.QualityOfServiceLevel = result.QualityOfServiceLevel; + if (publishPacket.QualityOfServiceLevel > 0) + { + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); + } + PendingMessagesQueue.Enqueue(publishPacket); } @@ -205,7 +212,7 @@ namespace MQTTnet.Server if (packet is MqttPingReqPacket) { - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPingRespPacket()); + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { new MqttPingRespPacket() }); } if (packet is MqttPubRelPacket pubRelPacket) @@ -215,7 +222,12 @@ namespace MQTTnet.Server if (packet is MqttPubRecPacket pubRecPacket) { - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, pubRecPacket.CreateResponse()); + var responsePacket = new MqttPubRelPacket + { + PacketIdentifier = pubRecPacket.PacketIdentifier + }; + + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { responsePacket }); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -246,11 +258,11 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = await SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { subscribeResult.ResponsePacket }).ConfigureAwait(false); if (subscribeResult.CloseConnection) { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttDisconnectPacket()).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { new MqttDisconnectPacket() }).ConfigureAwait(false); await StopAsync().ConfigureAwait(false); } @@ -260,7 +272,7 @@ namespace MQTTnet.Server private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var unsubscribeResult = await SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, unsubscribeResult); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { unsubscribeResult }); } private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) @@ -302,7 +314,7 @@ namespace MQTTnet.Server await ApplicationMessageReceivedCallback(this, applicationMessage).ConfigureAwait(false); var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, response).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { response }).ConfigureAwait(false); } private async Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) @@ -311,13 +323,13 @@ namespace MQTTnet.Server await ApplicationMessageReceivedCallback(this, applicationMessage).ConfigureAwait(false); var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, response).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { response }).ConfigureAwait(false); } private Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }; - return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, response); + return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { response }); } } } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs index fe57552..f158089 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs @@ -53,10 +53,10 @@ namespace MQTTnet.Server var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttConnAckPacket + await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { new MqttConnAckPacket { ConnectReturnCode = connectReturnCode - }).ConfigureAwait(false); + }}).ConfigureAwait(false); return; } @@ -64,11 +64,11 @@ namespace MQTTnet.Server var result = await GetOrCreateClientSessionAsync(connectPacket).ConfigureAwait(false); clientSession = result.Session; - await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttConnAckPacket + await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new[] { new MqttConnAckPacket { ConnectReturnCode = connectReturnCode, IsSessionPresent = result.IsExistingSession - }).ConfigureAwait(false); + }}).ConfigureAwait(false); ClientConnectedCallback?.Invoke(new ConnectedMqttClient { diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs index 97b9f04..e4ee921 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs @@ -14,7 +14,7 @@ namespace MQTTnet.Server private readonly Dictionary _subscriptions = new Dictionary(); private readonly IMqttServerOptions _options; private readonly string _clientId; - + public MqttClientSubscriptionsManager(IMqttServerOptions options, string clientId) { _options = options ?? throw new ArgumentNullException(nameof(options)); @@ -30,7 +30,11 @@ namespace MQTTnet.Server var result = new MqttClientSubscribeResult { - ResponsePacket = subscribePacket.CreateResponse(), + ResponsePacket = new MqttSubAckPacket + { + PacketIdentifier = subscribePacket.PacketIdentifier + }, + CloseConnection = false }; @@ -87,7 +91,10 @@ namespace MQTTnet.Server _semaphore.Release(); } - return unsubscribePacket.CreateResponse(); + return new MqttUnsubAckPacket + { + PacketIdentifier = unsubscribePacket.PacketIdentifier + }; } public async Task CheckSubscriptionsAsync(MqttApplicationMessage applicationMessage)