diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index fcf329d..67d814e 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -66,10 +66,10 @@ namespace MQTTnet.AspNetCore public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } public long BytesSent { get; set; } + public long BytesReceived { get; set; } - public Action ReadingPacketStartedCallback { get; set; } - public Action ReadingPacketCompletedCallback { get; set; } + public bool IsReadingPacket { get; private set; } IHttpContextFeature Http => Connection.Features.Get(); @@ -128,7 +128,7 @@ namespace MQTTnet.AspNetCore else { // we did receive something but the message is not yet complete - ReadingPacketStartedCallback?.Invoke(); + IsReadingPacket = true; } } else if (readResult.IsCompleted) @@ -147,14 +147,14 @@ namespace MQTTnet.AspNetCore } catch (Exception e) { - // completing the cannels makes sure that there is no more data read after a protocol error + // completing the channel makes sure that there is no more data read after a protocol error _input?.Complete(e); _output?.Complete(e); throw; } finally { - ReadingPacketCompletedCallback?.Invoke(); + IsReadingPacket = false; } cancellationToken.ThrowIfCancellationRequested(); diff --git a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs index a41e8c1..78c77f1 100644 --- a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs @@ -21,9 +21,7 @@ namespace MQTTnet.Adapter long BytesReceived { get; } - Action ReadingPacketStartedCallback { get; set; } - - Action ReadingPacketCompletedCallback { get; set; } + bool IsReadingPacket { get; } Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken); diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 2efcbae..5cf7fc7 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -51,10 +51,10 @@ namespace MQTTnet.Adapter public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } public long BytesSent => Interlocked.Read(ref _bytesSent); + public long BytesReceived => Interlocked.Read(ref _bytesReceived); - public Action ReadingPacketStartedCallback { get; set; } - public Action ReadingPacketCompletedCallback { get; set; } + public bool IsReadingPacket { get; private set; } public async Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { @@ -79,7 +79,7 @@ namespace MQTTnet.Adapter throw; } - WrapException(exception); + WrapAndThrowException(exception); } } @@ -107,7 +107,7 @@ namespace MQTTnet.Adapter throw; } - WrapException(exception); + WrapAndThrowException(exception); } } @@ -148,7 +148,7 @@ namespace MQTTnet.Adapter throw; } - WrapException(exception); + WrapAndThrowException(exception); } finally { @@ -214,7 +214,7 @@ namespace MQTTnet.Adapter throw; } - WrapException(exception); + WrapAndThrowException(exception); } return null; @@ -253,7 +253,7 @@ namespace MQTTnet.Adapter return null; } - ReadingPacketStartedCallback?.Invoke(); + IsReadingPacket = true; var fixedHeader = readFixedHeaderResult.FixedHeader; if (fixedHeader.RemainingLength == 0) @@ -293,7 +293,7 @@ namespace MQTTnet.Adapter } finally { - ReadingPacketCompletedCallback?.Invoke(); + IsReadingPacket = false; } } @@ -304,7 +304,7 @@ namespace MQTTnet.Adapter exception is MqttCommunicationException; } - static void WrapException(Exception exception) + static void WrapAndThrowException(Exception exception) { if (exception is IOException && exception.InnerException is SocketException innerException) { @@ -313,11 +313,15 @@ namespace MQTTnet.Adapter if (exception is SocketException socketException) { - if (socketException.SocketErrorCode == SocketError.ConnectionAborted || - socketException.SocketErrorCode == SocketError.OperationAborted) + if (socketException.SocketErrorCode == SocketError.OperationAborted) { throw new OperationCanceledException(); } + + if (socketException.SocketErrorCode == SocketError.ConnectionAborted) + { + throw new MqttCommunicationException(socketException); + } } if (exception is COMException comException) diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index cb87120..5263c1e 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -395,9 +395,9 @@ namespace MQTTnet.Client cancellationToken.ThrowIfCancellationRequested(); ushort identifier = 0; - if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier.HasValue) + if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0) { - identifier = packetWithIdentifier.PacketIdentifier.Value; + identifier = packetWithIdentifier.PacketIdentifier; } using (var packetAwaiter = _packetDispatcher.AddAwaiter(identifier)) @@ -636,7 +636,8 @@ namespace MQTTnet.Client { try { - var publishPacketDequeueResult = await _publishPacketReceiverQueue.TryDequeueAsync(cancellationToken); + var publishPacketDequeueResult = + await _publishPacketReceiverQueue.TryDequeueAsync(cancellationToken); if (!publishPacketDequeueResult.IsSuccess) { return; @@ -677,6 +678,9 @@ namespace MQTTnet.Client throw new MqttProtocolViolationException("Received a not supported QoS level."); } } + catch (OperationCanceledException) + { + } catch (Exception exception) { _logger.Error(exception, "Error while handling application message."); diff --git a/Source/MQTTnet/Formatter/IMqttDataConverter.cs b/Source/MQTTnet/Formatter/IMqttDataConverter.cs index 79c9192..8c3ed59 100644 --- a/Source/MQTTnet/Formatter/IMqttDataConverter.cs +++ b/Source/MQTTnet/Formatter/IMqttDataConverter.cs @@ -1,10 +1,12 @@ -using MQTTnet.Client.Connecting; +using System.Collections.Generic; +using MQTTnet.Client.Connecting; using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; using MQTTnet.Client.Publishing; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Packets; +using MQTTnet.Protocol; using MQTTnet.Server; using MqttClientSubscribeResult = MQTTnet.Client.Subscribing.MqttClientSubscribeResult; @@ -30,8 +32,12 @@ namespace MQTTnet.Formatter MqttSubscribePacket CreateSubscribePacket(MqttClientSubscribeOptions options); + MqttSubAckPacket CreateSubAckPacket(MqttSubscribePacket subscribePacket, Server.MqttClientSubscribeResult subscribeResult); + MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options); + MqttUnsubAckPacket CreateUnsubAckPacket(MqttUnsubscribePacket unsubscribePacket, List reasonCodes); + MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options); MqttClientPublishResult CreatePublishResult(MqttPubAckPacket pubAckPacket); diff --git a/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs b/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs index 87fb00b..7757acc 100644 --- a/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs +++ b/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs @@ -8,7 +8,7 @@ using MQTTnet.Packets; namespace MQTTnet.Formatter { - public class MqttPacketFormatterAdapter + public sealed class MqttPacketFormatterAdapter { IMqttPacketFormatter _formatter; @@ -77,17 +77,6 @@ namespace MQTTnet.Formatter UseProtocolVersion(protocolVersion); } - private void UseProtocolVersion(MqttProtocolVersion protocolVersion) - { - if (protocolVersion == MqttProtocolVersion.Unknown) - { - throw new InvalidOperationException("MQTT protocol version is invalid."); - } - - ProtocolVersion = protocolVersion; - _formatter = GetMqttPacketFormatter(protocolVersion, Writer); - } - public static IMqttPacketFormatter GetMqttPacketFormatter(MqttProtocolVersion protocolVersion, IMqttPacketWriter writer) { if (protocolVersion == MqttProtocolVersion.Unknown) @@ -116,7 +105,18 @@ namespace MQTTnet.Formatter } } - MqttProtocolVersion ParseProtocolVersion(ReceivedMqttPacket receivedMqttPacket) + void UseProtocolVersion(MqttProtocolVersion protocolVersion) + { + if (protocolVersion == MqttProtocolVersion.Unknown) + { + throw new InvalidOperationException("MQTT protocol version is invalid."); + } + + ProtocolVersion = protocolVersion; + _formatter = GetMqttPacketFormatter(protocolVersion, Writer); + } + + static MqttProtocolVersion ParseProtocolVersion(ReceivedMqttPacket receivedMqttPacket) { if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket)); diff --git a/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs b/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs index cf92ee6..2c55696 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs @@ -9,6 +9,7 @@ using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Server; using System; +using System.Collections.Generic; using System.Linq; using MqttClientSubscribeResult = MQTTnet.Client.Subscribing.MqttClientSubscribeResult; @@ -172,6 +173,18 @@ namespace MQTTnet.Formatter.V3 return subscribePacket; } + public MqttSubAckPacket CreateSubAckPacket(MqttSubscribePacket subscribePacket, Server.MqttClientSubscribeResult subscribeResult) + { + var subackPacket = new MqttSubAckPacket + { + PacketIdentifier = subscribePacket.PacketIdentifier + }; + + subackPacket.ReturnCodes.AddRange(subscribeResult.ReturnCodes); + + return subackPacket; + } + public MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -182,6 +195,15 @@ namespace MQTTnet.Formatter.V3 return unsubscribePacket; } + public MqttUnsubAckPacket CreateUnsubAckPacket(MqttUnsubscribePacket unsubscribePacket, List reasonCodes) + { + return new MqttUnsubAckPacket + { + PacketIdentifier = unsubscribePacket.PacketIdentifier, + ReasonCodes = reasonCodes + }; + } + public MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options) { if (options.ReasonCode != MqttClientDisconnectReason.NormalDisconnection || options.ReasonString != null) diff --git a/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs b/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs index 8afa86c..43ffc1d 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs @@ -210,7 +210,7 @@ namespace MQTTnet.Formatter.V3 var topic = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); - ushort? packetIdentifier = null; + ushort packetIdentifier = 0; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { packetIdentifier = receivedMqttPacket.Body.ReadTwoByteInteger(); @@ -414,12 +414,12 @@ namespace MQTTnet.Formatter.V3 static byte EncodePubRelPacket(MqttPubRelPacket packet, IMqttPacketWriter packetWriter) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("PubRel packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } @@ -432,12 +432,12 @@ namespace MQTTnet.Formatter.V3 if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("Publish packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); } else { @@ -471,36 +471,36 @@ namespace MQTTnet.Formatter.V3 static byte EncodePubAckPacket(MqttPubAckPacket packet, IMqttPacketWriter packetWriter) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } static byte EncodePubRecPacket(MqttPubRecPacket packet, IMqttPacketWriter packetWriter) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("PubRec packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } static byte EncodePubCompPacket(MqttPubCompPacket packet, IMqttPacketWriter packetWriter) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("PubComp packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } @@ -509,12 +509,12 @@ namespace MQTTnet.Formatter.V3 { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("Subscribe packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); if (packet.TopicFilters?.Count > 0) { @@ -530,12 +530,12 @@ namespace MQTTnet.Formatter.V3 static byte EncodeSubAckPacket(MqttSubAckPacket packet, IMqttPacketWriter packetWriter) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("SubAck packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); if (packet.ReturnCodes?.Any() == true) { @@ -552,12 +552,12 @@ namespace MQTTnet.Formatter.V3 { if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); if (packet.TopicFilters?.Any() == true) { @@ -572,12 +572,12 @@ namespace MQTTnet.Formatter.V3 static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, IMqttPacketWriter packetWriter) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } diff --git a/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs b/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs index 7b37866..09b0d0e 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV311PacketFormatter.cs @@ -4,7 +4,7 @@ using MQTTnet.Protocol; namespace MQTTnet.Formatter.V3 { - public class MqttV311PacketFormatter : MqttV310PacketFormatter + public sealed class MqttV311PacketFormatter : MqttV310PacketFormatter { public MqttV311PacketFormatter(IMqttPacketWriter packetWriter) : base(packetWriter) diff --git a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs index f55c347..ae2aeed 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs @@ -200,6 +200,18 @@ namespace MQTTnet.Formatter.V5 return packet; } + public MqttSubAckPacket CreateSubAckPacket(MqttSubscribePacket subscribePacket, Server.MqttClientSubscribeResult subscribeResult) + { + var subackPacket = new MqttSubAckPacket + { + PacketIdentifier = subscribePacket.PacketIdentifier + }; + + subackPacket.ReasonCodes.AddRange(subscribeResult.ReasonCodes); + + return subackPacket; + } + public MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -215,6 +227,15 @@ namespace MQTTnet.Formatter.V5 return packet; } + public MqttUnsubAckPacket CreateUnsubAckPacket(MqttUnsubscribePacket unsubscribePacket, List reasonCodes) + { + return new MqttUnsubAckPacket + { + PacketIdentifier = unsubscribePacket.PacketIdentifier, + ReasonCodes = reasonCodes + }; + } + public MqttDisconnectPacket CreateDisconnectPacket(MqttClientDisconnectOptions options) { var packet = new MqttDisconnectPacket(); diff --git a/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs b/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs index 9ebc740..bf978c4 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs @@ -6,14 +6,13 @@ using MQTTnet.Protocol; namespace MQTTnet.Formatter.V5 { - public class MqttV500PacketEncoder + public sealed class MqttV500PacketEncoder { - private readonly IMqttPacketWriter _packetWriter; + readonly IMqttPacketWriter _packetWriter; public MqttV500PacketEncoder() : this(new MqttPacketWriter()) { - } public MqttV500PacketEncoder(IMqttPacketWriter packetWriter) @@ -21,7 +20,6 @@ namespace MQTTnet.Formatter.V5 _packetWriter = packetWriter; } - public ArraySegment Encode(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); @@ -52,7 +50,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.FreeBuffer(); } - private static byte EncodePacket(MqttBasePacket packet, IMqttPacketWriter packetWriter) + static byte EncodePacket(MqttBasePacket packet, IMqttPacketWriter packetWriter) { switch (packet) { @@ -76,7 +74,7 @@ namespace MQTTnet.Formatter.V5 } } - private static byte EncodeConnectPacket(MqttConnectPacket packet, IMqttPacketWriter packetWriter) + static byte EncodeConnectPacket(MqttConnectPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -176,7 +174,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); } - private static byte EncodeConnAckPacket(MqttConnAckPacket packet, IMqttPacketWriter packetWriter) + static byte EncodeConnAckPacket(MqttConnAckPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -221,7 +219,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); } - private static byte EncodePublishPacket(MqttPublishPacket packet, IMqttPacketWriter packetWriter) + static byte EncodePublishPacket(MqttPublishPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); @@ -235,12 +233,12 @@ namespace MQTTnet.Formatter.V5 if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("Publish packet has no packet identifier."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); } else { @@ -287,12 +285,12 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } - private static byte EncodePubAckPacket(MqttPubAckPacket packet, IMqttPacketWriter packetWriter) + static byte EncodePubAckPacket(MqttPubAckPacket packet, IMqttPacketWriter packetWriter) { if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packetWriter == null) throw new ArgumentNullException(nameof(packetWriter)); - if (!packet.PacketIdentifier.HasValue) + if (packet.PacketIdentifier == 0) { throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); } @@ -302,7 +300,7 @@ namespace MQTTnet.Formatter.V5 throw new MqttProtocolViolationException("PubAck packet must contain a reason code."); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); var propertiesWriter = new MqttV500PropertiesWriter(); if (packet.Properties != null) @@ -320,9 +318,9 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); } - private static byte EncodePubRecPacket(MqttPubRecPacket packet, IMqttPacketWriter packetWriter) + static byte EncodePubRecPacket(MqttPubRecPacket packet, IMqttPacketWriter packetWriter) { - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); if (!packet.ReasonCode.HasValue) { @@ -336,7 +334,7 @@ namespace MQTTnet.Formatter.V5 propertiesWriter.WriteUserProperties(packet.Properties.UserProperties); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); if (packetWriter.Length > 0 || packet.ReasonCode.Value != MqttPubRecReasonCode.Success) { @@ -347,9 +345,9 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); } - private static byte EncodePubRelPacket(MqttPubRelPacket packet, IMqttPacketWriter packetWriter) + static byte EncodePubRelPacket(MqttPubRelPacket packet, IMqttPacketWriter packetWriter) { - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); if (!packet.ReasonCode.HasValue) { @@ -363,7 +361,7 @@ namespace MQTTnet.Formatter.V5 propertiesWriter.WriteUserProperties(packet.Properties.UserProperties); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); if (propertiesWriter.Length > 0 || packet.ReasonCode.Value != MqttPubRelReasonCode.Success) { @@ -374,16 +372,16 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); } - private static byte EncodePubCompPacket(MqttPubCompPacket packet, IMqttPacketWriter packetWriter) + static byte EncodePubCompPacket(MqttPubCompPacket packet, IMqttPacketWriter packetWriter) { - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); if (!packet.ReasonCode.HasValue) { ThrowReasonCodeNotSetException(); } - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); var propertiesWriter = new MqttV500PropertiesWriter(); if (packet.Properties != null) @@ -401,13 +399,13 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); } - private static byte EncodeSubscribePacket(MqttSubscribePacket packet, IMqttPacketWriter packetWriter) + static byte EncodeSubscribePacket(MqttSubscribePacket packet, IMqttPacketWriter packetWriter) { if (packet.TopicFilters?.Any() != true) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); var propertiesWriter = new MqttV500PropertiesWriter(); if (packet.Properties != null) @@ -448,13 +446,13 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); } - private static byte EncodeSubAckPacket(MqttSubAckPacket packet, IMqttPacketWriter packetWriter) + static byte EncodeSubAckPacket(MqttSubAckPacket packet, IMqttPacketWriter packetWriter) { if (packet.ReasonCodes?.Any() != true) throw new MqttProtocolViolationException("At least one reason code must be set[MQTT - 3.8.3 - 3]."); - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); var propertiesWriter = new MqttV500PropertiesWriter(); if (packet.Properties != null) @@ -473,13 +471,13 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); } - private static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, IMqttPacketWriter packetWriter) + static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, IMqttPacketWriter packetWriter) { if (packet.TopicFilters?.Any() != true) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); var propertiesWriter = new MqttV500PropertiesWriter(); if (packet.Properties != null) @@ -497,13 +495,13 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); } - private static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, IMqttPacketWriter packetWriter) + static byte EncodeUnsubAckPacket(MqttUnsubAckPacket packet, IMqttPacketWriter packetWriter) { if (packet.ReasonCodes?.Any() != true) throw new MqttProtocolViolationException("At least one reason code must be set[MQTT - 3.8.3 - 3]."); - ThrowIfPacketIdentifierIsInvalid(packet); + ThrowIfPacketIdentifierIsInvalid(packet.PacketIdentifier, packet); - packetWriter.Write(packet.PacketIdentifier.Value); + packetWriter.Write(packet.PacketIdentifier); var propertiesWriter = new MqttV500PropertiesWriter(); if (packet.Properties != null) @@ -522,7 +520,7 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); } - private static byte EncodeDisconnectPacket(MqttDisconnectPacket packet, IMqttPacketWriter packetWriter) + static byte EncodeDisconnectPacket(MqttDisconnectPacket packet, IMqttPacketWriter packetWriter) { if (!packet.ReasonCode.HasValue) { @@ -545,17 +543,17 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Disconnect); } - private static byte EncodePingReqPacket() + static byte EncodePingReqPacket() { return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PingReq); } - private static byte EncodePingRespPacket() + static byte EncodePingRespPacket() { return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PingResp); } - private static byte EncodeAuthPacket(MqttAuthPacket packet, IMqttPacketWriter packetWriter) + static byte EncodeAuthPacket(MqttAuthPacket packet, IMqttPacketWriter packetWriter) { packetWriter.Write((byte)packet.ReasonCode); @@ -573,14 +571,16 @@ namespace MQTTnet.Formatter.V5 return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Auth); } - private static void ThrowReasonCodeNotSetException() + static void ThrowReasonCodeNotSetException() { throw new MqttProtocolViolationException("The ReasonCode must be set for MQTT version 5."); } - private static void ThrowIfPacketIdentifierIsInvalid(IMqttPacketWithIdentifier packet) + static void ThrowIfPacketIdentifierIsInvalid(ushort packetIdentifier, MqttBasePacket packet) { - if (!packet.PacketIdentifier.HasValue) + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH(in cases where QoS > 0) Control Packets MUST contain a non-zero 16 - bit Packet Identifier[MQTT - 2.3.1 - 1]. + + if (packetIdentifier == 0) { throw new MqttProtocolViolationException($"Packet identifier is not set for {packet.GetType().Name}."); } diff --git a/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs b/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs index f632586..28931c8 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500PropertiesWriter.cs @@ -5,10 +5,9 @@ using MQTTnet.Protocol; namespace MQTTnet.Formatter.V5 { - public class MqttV500PropertiesWriter + public sealed class MqttV500PropertiesWriter { - // TODO: Consider lazy init on first write to avoid useless allocations. - private readonly MqttPacketWriter _packetWriter = new MqttPacketWriter(); + readonly MqttPacketWriter _packetWriter = new MqttPacketWriter(); public int Length => _packetWriter.Length; @@ -178,7 +177,7 @@ namespace MQTTnet.Formatter.V5 Write(MqttPropertyId.ResponseInformation, value); } - private void Write(MqttPropertyId id, bool? value) + void Write(MqttPropertyId id, bool? value) { if (!value.HasValue) { @@ -189,7 +188,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.Write(value.Value ? (byte)0x1 : (byte)0x0); } - private void Write(MqttPropertyId id, byte? value) + void Write(MqttPropertyId id, byte? value) { if (!value.HasValue) { @@ -200,7 +199,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.Write(value.Value); } - private void Write(MqttPropertyId id, ushort? value) + void Write(MqttPropertyId id, ushort? value) { if (!value.HasValue) { @@ -211,7 +210,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.Write(value.Value); } - private void WriteAsVariableLengthInteger(MqttPropertyId id, uint? value) + void WriteAsVariableLengthInteger(MqttPropertyId id, uint? value) { if (!value.HasValue) { @@ -222,7 +221,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.WriteVariableLengthInteger(value.Value); } - private void WriteAsFourByteInteger(MqttPropertyId id, uint? value) + void WriteAsFourByteInteger(MqttPropertyId id, uint? value) { if (!value.HasValue) { @@ -236,7 +235,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.Write((byte)value.Value); } - private void Write(MqttPropertyId id, string value) + void Write(MqttPropertyId id, string value) { if (value == null) { @@ -247,7 +246,7 @@ namespace MQTTnet.Formatter.V5 _packetWriter.WriteWithLengthPrefix(value); } - private void Write(MqttPropertyId id, byte[] value) + void Write(MqttPropertyId id, byte[] value) { if (value == null) { diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 5d96740..4f0a2d0 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -10,6 +10,7 @@ using System.Runtime.ExceptionServices; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; +using MQTTnet.Exceptions; namespace MQTTnet.Implementations { @@ -75,7 +76,7 @@ namespace MQTTnet.Implementations cancellationToken.ThrowIfCancellationRequested(); var networkStream = socket.GetStream(); - + if (_options.TlsOptions?.UseTls == true) { var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); @@ -144,7 +145,7 @@ namespace MQTTnet.Implementations if (stream == null) { - throw new ObjectDisposedException(nameof(stream)); + return 0; } return await stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); @@ -185,6 +186,10 @@ namespace MQTTnet.Implementations } await stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + + // This subsequent call is required to check whether the socket is still connected. + // Without this call a broken connection is only recognized at the next call. + await stream.WriteAsync(PlatformAbstractionLayer.EmptyByteArray, 0, 0, cancellationToken); } } catch (ObjectDisposedException) diff --git a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs index 5de1b41..1900ff4 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs @@ -47,7 +47,7 @@ namespace MQTTnet.Implementations { throw new ArgumentException("TLS certificate is not set."); } - + var tlsCertificate = options.TlsEndpointOptions.CertificateProvider.GetCertificate(); if (!tlsCertificate.HasPrivateKey) { @@ -73,16 +73,22 @@ namespace MQTTnet.Implementations void Cleanup() { - _cancellationTokenSource?.Cancel(false); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; - - foreach (var listener in _listeners) + try { - listener.Dispose(); + _cancellationTokenSource?.Cancel(false); } + finally + { + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; - _listeners.Clear(); + foreach (var listener in _listeners) + { + listener.Dispose(); + } + + _listeners.Clear(); + } } void RegisterListeners(MqttServerTcpEndpointBaseOptions options, X509Certificate2 tlsCertificate, CancellationToken cancellationToken) diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index 8e54c78..0955139 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -77,7 +77,7 @@ namespace MQTTnet.Implementations _socket.Bind(_localEndPoint); _socket.Listen(_options.ConnectionBacklog); - + Task.Run(() => AcceptClientConnectionsAsync(cancellationToken), cancellationToken).Forget(_logger); return true; @@ -152,9 +152,7 @@ namespace MQTTnet.Implementations _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6"); clientSocket.NoDelay = _options.NoDelay; - stream = clientSocket.GetStream(); - X509Certificate2 clientCertificate = null; if (_tlsCertificate != null) @@ -208,17 +206,17 @@ namespace MQTTnet.Implementations { stream?.Dispose(); clientSocket?.Dispose(); - - _logger.Verbose("Client '{0}' disconnected at TCP listener '{1}, {2}'.", - remoteEndPoint, - _localEndPoint, - _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6"); } catch (Exception disposeException) { _logger.Error(disposeException, "Error while cleaning up client connection"); } } + + _logger.Verbose("Client '{0}' disconnected at TCP listener '{1}, {2}'.", + remoteEndPoint, + _localEndPoint, + _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6"); } } } diff --git a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs index 0b683dc..735c64d 100644 --- a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs +++ b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs @@ -1,20 +1,25 @@ -using System.Threading.Tasks; +using System; +using System.Threading.Tasks; namespace MQTTnet.Implementations { public static class PlatformAbstractionLayer { - public static Task CompletedTask + public static readonly byte[] EmptyByteArray = new byte[0]; + +#if NET452 + public static Task CompletedTask => Task.FromResult(0); +#else + public static Task CompletedTask => Task.CompletedTask; +#endif + + public static void Sleep(TimeSpan timeout) { - get - { -#if NET452 - return Task.FromResult(0); +#if NET452 || NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 + System.Threading.Thread.Sleep(timeout); #else - return Task.CompletedTask; + Task.Delay(timeout).Wait(); #endif - } } - } } diff --git a/Source/MQTTnet/Internal/Disposable.cs b/Source/MQTTnet/Internal/Disposable.cs index e9b05ea..be733f2 100644 --- a/Source/MQTTnet/Internal/Disposable.cs +++ b/Source/MQTTnet/Internal/Disposable.cs @@ -4,7 +4,7 @@ namespace MQTTnet.Internal { public abstract class Disposable : IDisposable { - protected bool IsDisposed { get; private set; } = false; + protected bool IsDisposed { get; private set; } protected void ThrowIfDisposed() { @@ -18,13 +18,6 @@ namespace MQTTnet.Internal { } - // TODO: override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. - // ~Disposable() - // { - // // Do not change this code. Put cleanup code in Dispose(bool disposing) above. - // Dispose(false); - // } - // This code added to correctly implement the disposable pattern. public void Dispose() { diff --git a/Source/MQTTnet/Internal/TestMqttChannel.cs b/Source/MQTTnet/Internal/TestMqttChannel.cs index 954aa1b..fba1269 100644 --- a/Source/MQTTnet/Internal/TestMqttChannel.cs +++ b/Source/MQTTnet/Internal/TestMqttChannel.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Internal { public class TestMqttChannel : IMqttChannel { - private readonly MemoryStream _stream; + readonly MemoryStream _stream; public TestMqttChannel(MemoryStream stream) { diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs index 0ae2315..6935e74 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs @@ -27,10 +27,11 @@ namespace MQTTnet.PacketDispatcher { using (var timeoutToken = new CancellationTokenSource(timeout)) { - timeoutToken.Token.Register(() => Fail(new MqttCommunicationTimedOutException())); - - var packet = await _taskCompletionSource.Task.ConfigureAwait(false); - return (TPacket)packet; + using (timeoutToken.Token.Register(() => Fail(new MqttCommunicationTimedOutException()))) + { + var packet = await _taskCompletionSource.Task.ConfigureAwait(false); + return (TPacket)packet; + } } } diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs index 785fd35..ad873fb 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs @@ -34,9 +34,9 @@ namespace MQTTnet.PacketDispatcher } ushort identifier = 0; - if (packet is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier.HasValue) + if (packet is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0) { - identifier = packetWithIdentifier.PacketIdentifier.Value; + identifier = packetWithIdentifier.PacketIdentifier; } var type = packet.GetType(); diff --git a/Source/MQTTnet/Packets/IMqttPacketWithIdentifier.cs b/Source/MQTTnet/Packets/IMqttPacketWithIdentifier.cs index 6bbce0c..5f7f8e9 100644 --- a/Source/MQTTnet/Packets/IMqttPacketWithIdentifier.cs +++ b/Source/MQTTnet/Packets/IMqttPacketWithIdentifier.cs @@ -2,6 +2,6 @@ { public interface IMqttPacketWithIdentifier { - ushort? PacketIdentifier { get; set; } + ushort PacketIdentifier { get; set; } } } diff --git a/Source/MQTTnet/Packets/MqttAuthPacket.cs b/Source/MQTTnet/Packets/MqttAuthPacket.cs index cf93a00..a792570 100644 --- a/Source/MQTTnet/Packets/MqttAuthPacket.cs +++ b/Source/MQTTnet/Packets/MqttAuthPacket.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Packets /// /// Added in MQTTv5.0.0. /// - public class MqttAuthPacket : MqttBasePacket + public sealed class MqttAuthPacket : MqttBasePacket { public MqttAuthenticateReasonCode ReasonCode { get; set; } diff --git a/Source/MQTTnet/Packets/MqttAuthPacketProperties.cs b/Source/MQTTnet/Packets/MqttAuthPacketProperties.cs index 057004d..d4910af 100644 --- a/Source/MQTTnet/Packets/MqttAuthPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttAuthPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttAuthPacketProperties + public sealed class MqttAuthPacketProperties { public string AuthenticationMethod { get; set; } diff --git a/Source/MQTTnet/Packets/MqttBasePublishPacket.cs b/Source/MQTTnet/Packets/MqttBasePublishPacket.cs deleted file mode 100644 index ffafc53..0000000 --- a/Source/MQTTnet/Packets/MqttBasePublishPacket.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace MQTTnet.Packets -{ - public class MqttBasePublishPacket : MqttBasePacket, IMqttPacketWithIdentifier - { - public ushort? PacketIdentifier { get; set; } - } -} diff --git a/Source/MQTTnet/Packets/MqttConnAckPacket.cs b/Source/MQTTnet/Packets/MqttConnAckPacket.cs index 16a492c..6eceab7 100644 --- a/Source/MQTTnet/Packets/MqttConnAckPacket.cs +++ b/Source/MQTTnet/Packets/MqttConnAckPacket.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttConnAckPacket : MqttBasePacket + public sealed class MqttConnAckPacket : MqttBasePacket { public MqttConnectReturnCode? ReturnCode { get; set; } diff --git a/Source/MQTTnet/Packets/MqttConnAckPacketProperties.cs b/Source/MQTTnet/Packets/MqttConnAckPacketProperties.cs index d59735c..4cd003e 100644 --- a/Source/MQTTnet/Packets/MqttConnAckPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttConnAckPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttConnAckPacketProperties + public sealed class MqttConnAckPacketProperties { public uint? SessionExpiryInterval { get; set; } diff --git a/Source/MQTTnet/Packets/MqttConnectPacket.cs b/Source/MQTTnet/Packets/MqttConnectPacket.cs index 4eae2d6..6130262 100644 --- a/Source/MQTTnet/Packets/MqttConnectPacket.cs +++ b/Source/MQTTnet/Packets/MqttConnectPacket.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Packets { - public class MqttConnectPacket : MqttBasePacket + public sealed class MqttConnectPacket : MqttBasePacket { public string ClientId { get; set; } diff --git a/Source/MQTTnet/Packets/MqttConnectPacketProperties.cs b/Source/MQTTnet/Packets/MqttConnectPacketProperties.cs index 4cc405e..b9bd81e 100644 --- a/Source/MQTTnet/Packets/MqttConnectPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttConnectPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttConnectPacketProperties + public sealed class MqttConnectPacketProperties { public uint? WillDelayInterval { get; set; } diff --git a/Source/MQTTnet/Packets/MqttDisconnectPacket.cs b/Source/MQTTnet/Packets/MqttDisconnectPacket.cs index c198d9a..d28e57b 100644 --- a/Source/MQTTnet/Packets/MqttDisconnectPacket.cs +++ b/Source/MQTTnet/Packets/MqttDisconnectPacket.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttDisconnectPacket : MqttBasePacket + public sealed class MqttDisconnectPacket : MqttBasePacket { #region Added in MQTTv5 diff --git a/Source/MQTTnet/Packets/MqttDisconnectPacketProperties.cs b/Source/MQTTnet/Packets/MqttDisconnectPacketProperties.cs index e05529f..c4f8f79 100644 --- a/Source/MQTTnet/Packets/MqttDisconnectPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttDisconnectPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttDisconnectPacketProperties + public sealed class MqttDisconnectPacketProperties { public uint? SessionExpiryInterval { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPingReqPacket.cs b/Source/MQTTnet/Packets/MqttPingReqPacket.cs index 5c65da1..459e8b9 100644 --- a/Source/MQTTnet/Packets/MqttPingReqPacket.cs +++ b/Source/MQTTnet/Packets/MqttPingReqPacket.cs @@ -1,7 +1,10 @@ namespace MQTTnet.Packets { - public class MqttPingReqPacket : MqttBasePacket + public sealed class MqttPingReqPacket : MqttBasePacket { + // This is a minor performance improvement. + public static MqttPingReqPacket Instance = new MqttPingReqPacket(); + public override string ToString() { return "PingReq"; diff --git a/Source/MQTTnet/Packets/MqttPingRespPacket.cs b/Source/MQTTnet/Packets/MqttPingRespPacket.cs index b6f9bc3..de1db5e 100644 --- a/Source/MQTTnet/Packets/MqttPingRespPacket.cs +++ b/Source/MQTTnet/Packets/MqttPingRespPacket.cs @@ -1,7 +1,10 @@ namespace MQTTnet.Packets { - public class MqttPingRespPacket : MqttBasePacket + public sealed class MqttPingRespPacket : MqttBasePacket { + // This is a minor performance improvement. + public static MqttPingRespPacket Instance = new MqttPingRespPacket(); + public override string ToString() { return "PingResp"; diff --git a/Source/MQTTnet/Packets/MqttPubAckPacket.cs b/Source/MQTTnet/Packets/MqttPubAckPacket.cs index c919025..dbd8d71 100644 --- a/Source/MQTTnet/Packets/MqttPubAckPacket.cs +++ b/Source/MQTTnet/Packets/MqttPubAckPacket.cs @@ -2,8 +2,10 @@ namespace MQTTnet.Packets { - public class MqttPubAckPacket : MqttBasePublishPacket + public sealed class MqttPubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier { + public ushort PacketIdentifier { get; set; } + #region Added in MQTTv5 public MqttPubAckReasonCode? ReasonCode { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubAckPacketProperties.cs b/Source/MQTTnet/Packets/MqttPubAckPacketProperties.cs index e2debd4..58cb057 100644 --- a/Source/MQTTnet/Packets/MqttPubAckPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttPubAckPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttPubAckPacketProperties + public sealed class MqttPubAckPacketProperties { public string ReasonString { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubCompPacket.cs b/Source/MQTTnet/Packets/MqttPubCompPacket.cs index 0798819..573b0eb 100644 --- a/Source/MQTTnet/Packets/MqttPubCompPacket.cs +++ b/Source/MQTTnet/Packets/MqttPubCompPacket.cs @@ -2,8 +2,10 @@ namespace MQTTnet.Packets { - public class MqttPubCompPacket : MqttBasePublishPacket + public sealed class MqttPubCompPacket : MqttBasePacket, IMqttPacketWithIdentifier { + public ushort PacketIdentifier { get; set; } + #region Added in MQTTv5 public MqttPubCompReasonCode? ReasonCode { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubCompPacketProperties.cs b/Source/MQTTnet/Packets/MqttPubCompPacketProperties.cs index 35e040f..9d2e1a7 100644 --- a/Source/MQTTnet/Packets/MqttPubCompPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttPubCompPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttPubCompPacketProperties + public sealed class MqttPubCompPacketProperties { public string ReasonString { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubRecPacket.cs b/Source/MQTTnet/Packets/MqttPubRecPacket.cs index e462ce8..7e704be 100644 --- a/Source/MQTTnet/Packets/MqttPubRecPacket.cs +++ b/Source/MQTTnet/Packets/MqttPubRecPacket.cs @@ -2,8 +2,10 @@ namespace MQTTnet.Packets { - public class MqttPubRecPacket : MqttBasePublishPacket + public sealed class MqttPubRecPacket : MqttBasePacket, IMqttPacketWithIdentifier { + public ushort PacketIdentifier { get; set; } + #region Added in MQTTv5 public MqttPubRecReasonCode? ReasonCode { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubRecPacketProperties.cs b/Source/MQTTnet/Packets/MqttPubRecPacketProperties.cs index 0cd7225..2d302e8 100644 --- a/Source/MQTTnet/Packets/MqttPubRecPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttPubRecPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttPubRecPacketProperties + public sealed class MqttPubRecPacketProperties { public string ReasonString { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubRelPacket.cs b/Source/MQTTnet/Packets/MqttPubRelPacket.cs index cf2c71b..810c3ba 100644 --- a/Source/MQTTnet/Packets/MqttPubRelPacket.cs +++ b/Source/MQTTnet/Packets/MqttPubRelPacket.cs @@ -2,8 +2,10 @@ namespace MQTTnet.Packets { - public class MqttPubRelPacket : MqttBasePublishPacket + public sealed class MqttPubRelPacket : MqttBasePacket, IMqttPacketWithIdentifier { + public ushort PacketIdentifier { get; set; } + #region Added in MQTTv5 public MqttPubRelReasonCode? ReasonCode { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPubRelPacketProperties.cs b/Source/MQTTnet/Packets/MqttPubRelPacketProperties.cs index aa9625d..9cd610b 100644 --- a/Source/MQTTnet/Packets/MqttPubRelPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttPubRelPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttPubRelPacketProperties + public sealed class MqttPubRelPacketProperties { public string ReasonString { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPublishPacket.cs b/Source/MQTTnet/Packets/MqttPublishPacket.cs index 2f20b14..9fabdf9 100644 --- a/Source/MQTTnet/Packets/MqttPublishPacket.cs +++ b/Source/MQTTnet/Packets/MqttPublishPacket.cs @@ -2,8 +2,10 @@ namespace MQTTnet.Packets { - public class MqttPublishPacket : MqttBasePublishPacket + public sealed class MqttPublishPacket : MqttBasePacket, IMqttPacketWithIdentifier { + public ushort PacketIdentifier { get; set; } + public bool Retain { get; set; } public MqttQualityOfServiceLevel QualityOfServiceLevel { get; set; } diff --git a/Source/MQTTnet/Packets/MqttPublishPacketProperties.cs b/Source/MQTTnet/Packets/MqttPublishPacketProperties.cs index ba2d8c6..076edf5 100644 --- a/Source/MQTTnet/Packets/MqttPublishPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttPublishPacketProperties.cs @@ -3,7 +3,7 @@ using MQTTnet.Protocol; namespace MQTTnet.Packets { - public class MqttPublishPacketProperties + public sealed class MqttPublishPacketProperties { public MqttPayloadFormatIndicator? PayloadFormatIndicator { get; set; } diff --git a/Source/MQTTnet/Packets/MqttSubAckPacket.cs b/Source/MQTTnet/Packets/MqttSubAckPacket.cs index 4600a4e..e3848a9 100644 --- a/Source/MQTTnet/Packets/MqttSubAckPacket.cs +++ b/Source/MQTTnet/Packets/MqttSubAckPacket.cs @@ -4,9 +4,9 @@ using MQTTnet.Protocol; namespace MQTTnet.Packets { - public class MqttSubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier + public sealed class MqttSubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort? PacketIdentifier { get; set; } + public ushort PacketIdentifier { get; set; } public List ReturnCodes { get; set; } = new List(); diff --git a/Source/MQTTnet/Packets/MqttSubAckPacketProperties.cs b/Source/MQTTnet/Packets/MqttSubAckPacketProperties.cs index 74d3e7f..a8a2f57 100644 --- a/Source/MQTTnet/Packets/MqttSubAckPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttSubAckPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttSubAckPacketProperties + public sealed class MqttSubAckPacketProperties { public string ReasonString { get; set; } diff --git a/Source/MQTTnet/Packets/MqttSubscribePacket.cs b/Source/MQTTnet/Packets/MqttSubscribePacket.cs index 94061f0..d7f4da8 100644 --- a/Source/MQTTnet/Packets/MqttSubscribePacket.cs +++ b/Source/MQTTnet/Packets/MqttSubscribePacket.cs @@ -3,9 +3,9 @@ using System.Linq; namespace MQTTnet.Packets { - public class MqttSubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier + public sealed class MqttSubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort? PacketIdentifier { get; set; } + public ushort PacketIdentifier { get; set; } public List TopicFilters { get; set; } = new List(); diff --git a/Source/MQTTnet/Packets/MqttSubscribePacketProperties.cs b/Source/MQTTnet/Packets/MqttSubscribePacketProperties.cs index 34f58f2..43831ba 100644 --- a/Source/MQTTnet/Packets/MqttSubscribePacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttSubscribePacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttSubscribePacketProperties + public sealed class MqttSubscribePacketProperties { public uint? SubscriptionIdentifier { get; set; } diff --git a/Source/MQTTnet/Packets/MqttUnsubAckPacket.cs b/Source/MQTTnet/Packets/MqttUnsubAckPacket.cs index a17261e..394da48 100644 --- a/Source/MQTTnet/Packets/MqttUnsubAckPacket.cs +++ b/Source/MQTTnet/Packets/MqttUnsubAckPacket.cs @@ -4,9 +4,9 @@ using MQTTnet.Protocol; namespace MQTTnet.Packets { - public class MqttUnsubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier + public sealed class MqttUnsubAckPacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort? PacketIdentifier { get; set; } + public ushort PacketIdentifier { get; set; } #region Added in MQTTv5 diff --git a/Source/MQTTnet/Packets/MqttUnsubAckPacketProperties.cs b/Source/MQTTnet/Packets/MqttUnsubAckPacketProperties.cs index 2a102c6..89ac074 100644 --- a/Source/MQTTnet/Packets/MqttUnsubAckPacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttUnsubAckPacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttUnsubAckPacketProperties + public sealed class MqttUnsubAckPacketProperties { public string ReasonString { get; set; } diff --git a/Source/MQTTnet/Packets/MqttUnsubscribePacket.cs b/Source/MQTTnet/Packets/MqttUnsubscribePacket.cs index 7b07a24..76a9d37 100644 --- a/Source/MQTTnet/Packets/MqttUnsubscribePacket.cs +++ b/Source/MQTTnet/Packets/MqttUnsubscribePacket.cs @@ -2,9 +2,9 @@ namespace MQTTnet.Packets { - public class MqttUnsubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier + public sealed class MqttUnsubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier { - public ushort? PacketIdentifier { get; set; } + public ushort PacketIdentifier { get; set; } public List TopicFilters { get; set; } = new List(); diff --git a/Source/MQTTnet/Packets/MqttUnsubscribePacketProperties.cs b/Source/MQTTnet/Packets/MqttUnsubscribePacketProperties.cs index 3ec68f1..d98a4f9 100644 --- a/Source/MQTTnet/Packets/MqttUnsubscribePacketProperties.cs +++ b/Source/MQTTnet/Packets/MqttUnsubscribePacketProperties.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttUnsubscribePacketProperties + public sealed class MqttUnsubscribePacketProperties { public List UserProperties { get; set; } } diff --git a/Source/MQTTnet/Packets/MqttUserProperty.cs b/Source/MQTTnet/Packets/MqttUserProperty.cs index dbfae07..f2d2fa1 100644 --- a/Source/MQTTnet/Packets/MqttUserProperty.cs +++ b/Source/MQTTnet/Packets/MqttUserProperty.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Packets { - public class MqttUserProperty + public sealed class MqttUserProperty { public MqttUserProperty(string name, string value) { diff --git a/Source/MQTTnet/Server/IMqttServerOptions.cs b/Source/MQTTnet/Server/IMqttServerOptions.cs index 05dab5e..233048e 100644 --- a/Source/MQTTnet/Server/IMqttServerOptions.cs +++ b/Source/MQTTnet/Server/IMqttServerOptions.cs @@ -12,6 +12,7 @@ namespace MQTTnet.Server MqttPendingMessagesOverflowStrategy PendingMessagesOverflowStrategy { get; } TimeSpan DefaultCommunicationTimeout { get; } + TimeSpan KeepAliveMonitorInterval { get; } IMqttServerConnectionValidator ConnectionValidator { get; } IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; } diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index d501c8e..362af39 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -23,7 +23,6 @@ namespace MQTTnet.Server readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); readonly IMqttRetainedMessagesManager _retainedMessagesManager; - readonly MqttClientKeepAliveMonitor _keepAliveMonitor; readonly MqttClientSessionsManager _sessionsManager; readonly IMqttNetScopedLogger _logger; @@ -36,7 +35,6 @@ namespace MQTTnet.Server readonly DateTime _connectedTimestamp; volatile Task _packageReceiverTask; - DateTime _lastPacketReceivedTimestamp; DateTime _lastNonKeepAlivePacketReceivedTimestamp; long _receivedPacketsCount; @@ -69,17 +67,19 @@ namespace MQTTnet.Server if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateScopedLogger(nameof(MqttClientConnection)); - _keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, () => StopAsync(), logger); - _connectedTimestamp = DateTime.UtcNow; - _lastPacketReceivedTimestamp = _connectedTimestamp; - _lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp; + LastPacketReceivedTimestamp = _connectedTimestamp; + _lastNonKeepAlivePacketReceivedTimestamp = LastPacketReceivedTimestamp; } public MqttConnectPacket ConnectPacket { get; } public string ClientId => ConnectPacket.ClientId; + public bool IsReadingPacket => _channelAdapter.IsReadingPacket; + + public DateTime LastPacketReceivedTimestamp { get; private set; } + public MqttClientSession Session { get; } public Task StopAsync(bool isTakeover = false) @@ -115,7 +115,7 @@ namespace MQTTnet.Server status.SentPacketsCount = Interlocked.Read(ref _sentPacketsCount); status.ConnectedTimestamp = _connectedTimestamp; - status.LastPacketReceivedTimestamp = _lastPacketReceivedTimestamp; + status.LastPacketReceivedTimestamp = LastPacketReceivedTimestamp; status.LastNonKeepAlivePacketReceivedTimestamp = _lastNonKeepAlivePacketReceivedTimestamp; status.BytesSent = _channelAdapter.BytesSent; @@ -140,16 +140,10 @@ namespace MQTTnet.Server { _logger.Info("Client '{0}': Session started.", ClientId); - _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; - _channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted; - Session.WillMessage = ConnectPacket.WillMessage; Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger); - // TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor. - _keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token); - await SendAsync(_channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(_connectionValidatorContext)).ConfigureAwait(false); Session.IsCleanSession = false; @@ -164,15 +158,13 @@ namespace MQTTnet.Server } Interlocked.Increment(ref _sentPacketsCount); - _lastPacketReceivedTimestamp = DateTime.UtcNow; + LastPacketReceivedTimestamp = DateTime.UtcNow; if (!(packet is MqttPingReqPacket || packet is MqttPingRespPacket)) { - _lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp; + _lastNonKeepAlivePacketReceivedTimestamp = LastPacketReceivedTimestamp; } - _keepAliveMonitor.PacketReceived(); - if (packet is MqttPublishPacket publishPacket) { await HandleIncomingPublishPacketAsync(publishPacket).ConfigureAwait(false); @@ -252,9 +244,6 @@ namespace MQTTnet.Server _packetDispatcher.Reset(); - _channelAdapter.ReadingPacketStartedCallback = null; - _channelAdapter.ReadingPacketCompletedCallback = null; - _packageReceiverTask = null; if (_isTakeover) @@ -302,10 +291,10 @@ namespace MQTTnet.Server async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) { - // TODO: Let the channel adapter create the packet. var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); + var subAckPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreateSubAckPacket(subscribePacket, subscribeResult); - await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false); + await SendAsync(subAckPacket).ConfigureAwait(false); if (subscribeResult.CloseConnection) { @@ -318,9 +307,10 @@ namespace MQTTnet.Server async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) { - // TODO: Let the channel adapter create the packet. - var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); - await SendAsync(unsubscribeResult).ConfigureAwait(false); + var reasonCodes = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); + var unsubAckPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreateUnsubAckPacket(unsubscribePacket, reasonCodes); + + await SendAsync(unsubAckPacket).ConfigureAwait(false); } Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) @@ -389,7 +379,7 @@ namespace MQTTnet.Server { while (!cancellationToken.IsCancellationRequested) { - queuedApplicationMessage = await Session.ApplicationMessagesQueue.TakeAsync(cancellationToken).ConfigureAwait(false); + queuedApplicationMessage = await Session.ApplicationMessagesQueue.DequeueAsync(cancellationToken).ConfigureAwait(false); if (queuedApplicationMessage == null) { return; @@ -503,15 +493,5 @@ namespace MQTTnet.Server Interlocked.Increment(ref _receivedApplicationMessagesCount); } } - - void OnAdapterReadingPacketCompleted() - { - _keepAliveMonitor?.Resume(); - } - - void OnAdapterReadingPacketStarted() - { - _keepAliveMonitor?.Pause(); - } } } diff --git a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs deleted file mode 100644 index c39be49..0000000 --- a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs +++ /dev/null @@ -1,93 +0,0 @@ -using MQTTnet.Diagnostics; -using MQTTnet.Internal; -using System; -using System.Diagnostics; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.Server -{ - public sealed class MqttClientKeepAliveMonitor - { - readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch(); - - readonly string _clientId; - readonly Func _keepAliveElapsedCallback; - readonly IMqttNetScopedLogger _logger; - - bool _isPaused; - - public MqttClientKeepAliveMonitor(string clientId, Func keepAliveElapsedCallback, IMqttNetLogger logger) - { - _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); - _keepAliveElapsedCallback = keepAliveElapsedCallback ?? throw new ArgumentNullException(nameof(keepAliveElapsedCallback)); - - if (logger == null) throw new ArgumentNullException(nameof(logger)); - _logger = logger.CreateScopedLogger(nameof(MqttClientKeepAliveMonitor)); - } - - public void Start(int keepAlivePeriod, CancellationToken cancellationToken) - { - if (keepAlivePeriod == 0) - { - return; - } - - Task.Run(() => RunAsync(keepAlivePeriod, cancellationToken), cancellationToken).Forget(_logger); - } - - public void Pause() - { - _isPaused = true; - } - - public void Resume() - { - _isPaused = false; - } - - public void PacketReceived() - { - _lastPacketReceivedTracker.Restart(); - } - - async Task RunAsync(int keepAlivePeriod, CancellationToken cancellationToken) - { - try - { - _lastPacketReceivedTracker.Restart(); - - while (!cancellationToken.IsCancellationRequested) - { - // Values described here: [MQTT-3.1.2-24]. - // If the client sends 5 sec. the server will allow up to 7.5 seconds. - // If the client sends 1 sec. the server will allow up to 1.5 seconds. - if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds >= keepAlivePeriod * 1.5D) - { - _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientId); - await _keepAliveElapsedCallback().ConfigureAwait(false); - - return; - } - - // The server checks the keep alive timeout every 50 % of the overall keep alive timeout - // because the server allows 1.5 times the keep alive value. This means that a value of 5 allows - // up to 7.5 seconds. With an interval of 2.5 (5 / 2) the 7.5 is also affected. Waiting the whole - // keep alive time will hit at 10 instead of 7.5 (but only one time instead of two times). - await Task.Delay(TimeSpan.FromSeconds(keepAlivePeriod * 0.5D), cancellationToken).ConfigureAwait(false); - } - } - catch (OperationCanceledException) - { - } - catch (Exception exception) - { - _logger.Error(exception, "Client '{0}': Unhandled exception while checking keep alive timeouts.", _clientId); - } - finally - { - _logger.Verbose("Client '{0}': Stopped checking keep alive timeout.", _clientId); - } - } - } -} diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 274eccf..81f648e 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; namespace MQTTnet.Server { - public class MqttClientSession + public sealed class MqttClientSession { readonly IMqttNetScopedLogger _logger; diff --git a/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs index 19c4f93..6ddba92 100644 --- a/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs +++ b/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs @@ -6,11 +6,11 @@ using System.Threading.Tasks; namespace MQTTnet.Server { - public class MqttClientSessionApplicationMessagesQueue : Disposable + public sealed class MqttClientSessionApplicationMessagesQueue : IDisposable { - private readonly AsyncQueue _messageQueue = new AsyncQueue(); - - private readonly IMqttServerOptions _options; + readonly AsyncQueue _messageQueue = new AsyncQueue(); + + readonly IMqttServerOptions _options; public MqttClientSessionApplicationMessagesQueue(IMqttServerOptions options) { @@ -32,22 +32,6 @@ namespace MQTTnet.Server }); } - public void Clear() - { - _messageQueue.Clear(); - } - - public async Task TakeAsync(CancellationToken cancellationToken) - { - var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); - if (!dequeueResult.IsSuccess) - { - return null; - } - - return dequeueResult.Item; - } - public void Enqueue(MqttQueuedApplicationMessage queuedApplicationMessage) { if (queuedApplicationMessage == null) throw new ArgumentNullException(nameof(queuedApplicationMessage)); @@ -71,14 +55,25 @@ namespace MQTTnet.Server } } - protected override void Dispose(bool disposing) + public async Task DequeueAsync(CancellationToken cancellationToken) { - if (disposing) + var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); + if (!dequeueResult.IsSuccess) { - _messageQueue.Dispose(); + return null; } - base.Dispose(disposing); + return dequeueResult.Item; + } + + public void Clear() + { + _messageQueue.Clear(); + } + + public void Dispose() + { + _messageQueue?.Dispose(); } } } diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 19c8478..05c614e 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -9,6 +9,7 @@ using MQTTnet.Server.Status; using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -63,6 +64,11 @@ namespace MQTTnet.Server } } + public List GetConnections() + { + return _connections.Values.ToList(); + } + public Task HandleClientConnectionAsync(IMqttChannelAdapter clientAdapter) { if (clientAdapter is null) throw new ArgumentNullException(nameof(clientAdapter)); diff --git a/Source/MQTTnet/Server/MqttClientSubscribeResult.cs b/Source/MQTTnet/Server/MqttClientSubscribeResult.cs index fe75c3b..f25d6c7 100644 --- a/Source/MQTTnet/Server/MqttClientSubscribeResult.cs +++ b/Source/MQTTnet/Server/MqttClientSubscribeResult.cs @@ -1,10 +1,13 @@ -using MQTTnet.Packets; +using System.Collections.Generic; +using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientSubscribeResult + public sealed class MqttClientSubscribeResult { - public MqttSubAckPacket ResponsePacket { get; set; } + public List ReturnCodes { get; } = new List(); + + public List ReasonCodes { get; } = new List(); public bool CloseConnection { get; set; } } diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index 48f5bc9..2d21629 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -7,12 +7,12 @@ using System.Threading.Tasks; namespace MQTTnet.Server { - public class MqttClientSubscriptionsManager + public sealed class MqttClientSubscriptionsManager { - private readonly Dictionary _subscriptions = new Dictionary(); - private readonly MqttClientSession _clientSession; - private readonly IMqttServerOptions _serverOptions; - private readonly MqttServerEventDispatcher _eventDispatcher; + readonly Dictionary _subscriptions = new Dictionary(); + readonly MqttClientSession _clientSession; + readonly IMqttServerOptions _serverOptions; + readonly MqttServerEventDispatcher _eventDispatcher; public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions) { @@ -28,15 +28,7 @@ namespace MQTTnet.Server if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); - var result = new MqttClientSubscribeResult - { - ResponsePacket = new MqttSubAckPacket - { - PacketIdentifier = subscribePacket.PacketIdentifier - }, - - CloseConnection = false - }; + var result = new MqttClientSubscribeResult(); foreach (var originalTopicFilter in subscribePacket.TopicFilters) { @@ -46,13 +38,13 @@ namespace MQTTnet.Server if (finalTopicFilter == null || string.IsNullOrEmpty(finalTopicFilter.Topic) || !interceptorContext.AcceptSubscription) { - result.ResponsePacket.ReturnCodes.Add(MqttSubscribeReturnCode.Failure); - result.ResponsePacket.ReasonCodes.Add(MqttSubscribeReasonCode.UnspecifiedError); + result.ReturnCodes.Add(MqttSubscribeReturnCode.Failure); + result.ReasonCodes.Add(MqttSubscribeReasonCode.UnspecifiedError); } else { - result.ResponsePacket.ReturnCodes.Add(ConvertToSubscribeReturnCode(finalTopicFilter.QualityOfServiceLevel)); - result.ResponsePacket.ReasonCodes.Add(ConvertToSubscribeReasonCode(finalTopicFilter.QualityOfServiceLevel)); + result.ReturnCodes.Add(ConvertToSubscribeReturnCode(finalTopicFilter.QualityOfServiceLevel)); + result.ReasonCodes.Add(ConvertToSubscribeReasonCode(finalTopicFilter.QualityOfServiceLevel)); } if (interceptorContext.CloseConnection) @@ -98,21 +90,18 @@ namespace MQTTnet.Server } } - public async Task UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket) + public async Task> UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket) { if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); - var unsubAckPacket = new MqttUnsubAckPacket - { - PacketIdentifier = unsubscribePacket.PacketIdentifier - }; + var reasonCodes = new List(); foreach (var topicFilter in unsubscribePacket.TopicFilters) { var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false); if (!interceptorContext.AcceptUnsubscription) { - unsubAckPacket.ReasonCodes.Add(MqttUnsubscribeReasonCode.ImplementationSpecificError); + reasonCodes.Add(MqttUnsubscribeReasonCode.ImplementationSpecificError); continue; } @@ -120,11 +109,11 @@ namespace MQTTnet.Server { if (_subscriptions.Remove(topicFilter)) { - unsubAckPacket.ReasonCodes.Add(MqttUnsubscribeReasonCode.Success); + reasonCodes.Add(MqttUnsubscribeReasonCode.Success); } else { - unsubAckPacket.ReasonCodes.Add(MqttUnsubscribeReasonCode.NoSubscriptionExisted); + reasonCodes.Add(MqttUnsubscribeReasonCode.NoSubscriptionExisted); } } } @@ -134,7 +123,7 @@ namespace MQTTnet.Server await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } - return unsubAckPacket; + return reasonCodes; } public async Task UnsubscribeAsync(IEnumerable topicFilters) @@ -184,7 +173,7 @@ namespace MQTTnet.Server return CreateSubscriptionResult(qosLevel, qosLevels); } - private static MqttSubscribeReturnCode ConvertToSubscribeReturnCode(MqttQualityOfServiceLevel qualityOfServiceLevel) + static MqttSubscribeReturnCode ConvertToSubscribeReturnCode(MqttQualityOfServiceLevel qualityOfServiceLevel) { switch (qualityOfServiceLevel) { @@ -195,7 +184,7 @@ namespace MQTTnet.Server } } - private static MqttSubscribeReasonCode ConvertToSubscribeReasonCode(MqttQualityOfServiceLevel qualityOfServiceLevel) + static MqttSubscribeReasonCode ConvertToSubscribeReasonCode(MqttQualityOfServiceLevel qualityOfServiceLevel) { switch (qualityOfServiceLevel) { @@ -206,7 +195,7 @@ namespace MQTTnet.Server } } - private async Task InterceptSubscribeAsync(MqttTopicFilter topicFilter) + async Task InterceptSubscribeAsync(MqttTopicFilter topicFilter) { var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); if (_serverOptions.SubscriptionInterceptor != null) @@ -217,7 +206,7 @@ namespace MQTTnet.Server return context; } - private async Task InterceptUnsubscribeAsync(string topicFilter) + async Task InterceptUnsubscribeAsync(string topicFilter) { var context = new MqttUnsubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); if (_serverOptions.UnsubscriptionInterceptor != null) @@ -228,7 +217,7 @@ namespace MQTTnet.Server return context; } - private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet subscribedQoSLevels) + static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet subscribedQoSLevels) { MqttQualityOfServiceLevel effectiveQoS; if (subscribedQoSLevels.Contains(qosLevel)) diff --git a/Source/MQTTnet/Server/MqttServerKeepAliveMonitor.cs b/Source/MQTTnet/Server/MqttServerKeepAliveMonitor.cs new file mode 100644 index 0000000..4664f87 --- /dev/null +++ b/Source/MQTTnet/Server/MqttServerKeepAliveMonitor.cs @@ -0,0 +1,115 @@ +using MQTTnet.Diagnostics; +using MQTTnet.Internal; +using System; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Implementations; + +namespace MQTTnet.Server +{ + public sealed class MqttServerKeepAliveMonitor + { + readonly IMqttServerOptions _options; + readonly MqttClientSessionsManager _sessionsManager; + readonly IMqttNetScopedLogger _logger; + + public MqttServerKeepAliveMonitor(IMqttServerOptions options, MqttClientSessionsManager sessionsManager, IMqttNetLogger logger) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); + + if (logger == null) throw new ArgumentNullException(nameof(logger)); + _logger = logger.CreateScopedLogger(nameof(MqttServerKeepAliveMonitor)); + } + + public void Start(CancellationToken cancellationToken) + { + // The keep alive monitor spawns a real new thread (LongRunning) because it does not + // support async/await. Async etc. is avoided here because the thread will usually check + // the connections every few milliseconds and thus the context changes (due to async) are + // only consuming resources. Also there is just 1 thread for the entire server which is fine at all! + Task.Factory.StartNew(_ => DoWork(cancellationToken), cancellationToken, TaskCreationOptions.LongRunning).Forget(_logger); + } + + void DoWork(CancellationToken cancellationToken) + { + try + { + _logger.Info("Starting keep alive monitor."); + + while (!cancellationToken.IsCancellationRequested) + { + TryMaintainConnections(); + PlatformAbstractionLayer.Sleep(_options.KeepAliveMonitorInterval); + } + } + catch (OperationCanceledException) + { + } + catch (Exception exception) + { + _logger.Error(exception, "Unhandled exception while checking keep alive timeouts."); + } + finally + { + _logger.Verbose("Stopped checking keep alive timeout."); + } + } + + void TryMaintainConnections() + { + var now = DateTime.UtcNow; + foreach (var connection in _sessionsManager.GetConnections()) + { + TryMaintainConnection(connection, now); + } + } + + void TryMaintainConnection(MqttClientConnection connection, DateTime now) + { + try + { + //if (connection.IsStopped) + //{ + // // The connection is already dead so there is no need to check it. + // return; + //} + + if (connection.ConnectPacket.KeepAlivePeriod == 0) + { + // The keep alive feature is not used by the current connection. + return; + } + + if (connection.IsReadingPacket) + { + // The connection is currently reading a (large) packet. So it is obviously + // doing something and thus "connected". + return; + } + + // Values described here: [MQTT-3.1.2-24]. + // If the client sends 5 sec. the server will allow up to 7.5 seconds. + // If the client sends 1 sec. the server will allow up to 1.5 seconds. + var maxDurationWithoutPacket = connection.ConnectPacket.KeepAlivePeriod * 1.5D; + + var secondsWithoutPackage = (now - connection.LastPacketReceivedTimestamp).TotalSeconds; + if (secondsWithoutPackage < maxDurationWithoutPacket) + { + // A packet was received before the timeout is affected. + return; + } + + _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", connection.ClientId); + + // Execute the disconnection in background so that the keep alive monitor can continue + // with checking other connections. + Task.Run(() => connection.StopAsync()); + } + catch (Exception exception) + { + _logger.Error(exception, "Client {0}: Unhandled exception while checking keep alive timeouts.", connection.ClientId); + } + } + } +} diff --git a/Source/MQTTnet/Server/MqttServerOptions.cs b/Source/MQTTnet/Server/MqttServerOptions.cs index 45de09c..31f56e5 100644 --- a/Source/MQTTnet/Server/MqttServerOptions.cs +++ b/Source/MQTTnet/Server/MqttServerOptions.cs @@ -18,6 +18,8 @@ namespace MQTTnet.Server public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15); + public TimeSpan KeepAliveMonitorInterval { get; set; } = TimeSpan.FromMilliseconds(500); + public IMqttServerConnectionValidator ConnectionValidator { get; set; } public IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; set; } diff --git a/Tests/MQTTnet.Core.Tests/AsyncLock_Tests.cs b/Tests/MQTTnet.Core.Tests/AsyncLock_Tests.cs index ad1d10c..a4e6670 100644 --- a/Tests/MQTTnet.Core.Tests/AsyncLock_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/AsyncLock_Tests.cs @@ -1,4 +1,5 @@ -using System.Threading; +using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Internal; @@ -9,7 +10,65 @@ namespace MQTTnet.Tests public class AsyncLock_Tests { [TestMethod] - public void AsyncLock() + public async Task Lock_Serial_Calls() + { + var sum = 0; + + var @lock = new AsyncLock(); + for (var i = 0; i < 100; i++) + { + using (await @lock.WaitAsync().ConfigureAwait(false)) + { + sum++; + } + } + + Assert.AreEqual(100, sum); + } + + [TestMethod] + [ExpectedException(typeof(TaskCanceledException))] + public async Task Test_Cancellation() + { + var @lock = new AsyncLock(); + + // This call will never "release" the lock due to missing _using_. + await @lock.WaitAsync().ConfigureAwait(false); + + using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + await @lock.WaitAsync(cts.Token).ConfigureAwait(false); + } + } + + //[TestMethod] + //public async Task Test_Cancellation_With_Later_Access() + //{ + // var @lock = new AsyncLock(); + + // var releaser = await @lock.WaitAsync().ConfigureAwait(false); + + // try + // { + // using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + // { + // await @lock.WaitAsync(cts.Token).ConfigureAwait(false); + // } + // } + // catch (OperationCanceledException) + // { + // } + + // releaser.Dispose(); + + // using (await @lock.WaitAsync().ConfigureAwait(false)) + // { + // // When the method finished, the thread got access. + // } + //} + + [TestMethod] + public void Lock_10_Parallel_Tasks() { const int ThreadsCount = 10; diff --git a/Tests/MQTTnet.Core.Tests/AsyncQueue_Tests.cs b/Tests/MQTTnet.Core.Tests/AsyncQueue_Tests.cs index c222cdf..fb1ad93 100644 --- a/Tests/MQTTnet.Core.Tests/AsyncQueue_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/AsyncQueue_Tests.cs @@ -1,4 +1,5 @@ -using System.Threading; +using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Internal; @@ -36,9 +37,22 @@ namespace MQTTnet.Tests Assert.AreEqual(3, queue.Count); } + [TestMethod] + public async Task Cancellation() + { + var queue = new AsyncQueue(); + + bool success; + using (var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + success = (await queue.TryDequeueAsync(cancellationTokenSource.Token)).IsSuccess; + } + + Assert.AreEqual(false, success); + } [TestMethod] - public async Task Preserve_ProcessAsync() + public async Task Process_Async() { var queue = new AsyncQueue(); @@ -50,7 +64,37 @@ namespace MQTTnet.Tests sum += (await queue.TryDequeueAsync(CancellationToken.None)).Item; } }); - + + queue.Enqueue(1); + await Task.Delay(500); + + queue.Enqueue(2); + await Task.Delay(500); + + queue.Enqueue(3); + await Task.Delay(500); + + Assert.AreEqual(6, sum); + Assert.AreEqual(TaskStatus.RanToCompletion, worker.Status); + } + + [TestMethod] + public async Task Process_Async_With_Initial_Delay() + { + var queue = new AsyncQueue(); + + var sum = 0; + var worker = Task.Run(async () => + { + while (sum < 6) + { + sum += (await queue.TryDequeueAsync(CancellationToken.None)).Item; + } + }); + + // This line is the diff to test _Process_Async_ + await Task.Delay(500); + queue.Enqueue(1); await Task.Delay(500); diff --git a/Tests/MQTTnet.Core.Tests/LowLevelMqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/LowLevelMqttClient_Tests.cs index 478df34..7aa9cd3 100644 --- a/Tests/MQTTnet.Core.Tests/LowLevelMqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/LowLevelMqttClient_Tests.cs @@ -1,10 +1,12 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client.Options; using MQTTnet.LowLevelClient; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Tests.Mockups; using System.Collections.Generic; +using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -34,7 +36,7 @@ namespace MQTTnet.Tests { using (var testEnvironment = new TestEnvironment(TestContext)) { - var server = await testEnvironment.StartServerAsync(); + await testEnvironment.StartServerAsync(); var factory = new MqttFactory(); var lowLevelClient = factory.CreateLowLevelMqttClient(); @@ -50,7 +52,7 @@ namespace MQTTnet.Tests { using (var testEnvironment = new TestEnvironment(TestContext)) { - var server = await testEnvironment.StartServerAsync(); + await testEnvironment.StartServerAsync(); var factory = new MqttFactory(); var lowLevelClient = factory.CreateLowLevelMqttClient(); @@ -71,7 +73,7 @@ namespace MQTTnet.Tests { using (var testEnvironment = new TestEnvironment(TestContext)) { - var server = await testEnvironment.StartServerAsync(); + await testEnvironment.StartServerAsync(); var factory = new MqttFactory(); var lowLevelClient = factory.CreateLowLevelMqttClient(); @@ -89,6 +91,39 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task Loose_Connection() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + testEnvironment.ServerPort = 8364; + var server = await testEnvironment.StartServerAsync(); + var client = await testEnvironment.ConnectLowLevelClientAsync(o => o.WithCommunicationTimeout(TimeSpan.Zero)); + + await Authenticate(client).ConfigureAwait(false); + + await server.StopAsync(); + + await Task.Delay(1000); + + try + { + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None).ConfigureAwait(false); + } + catch (MqttCommunicationException exception) + { + Assert.IsTrue(exception.InnerException is SocketException); + return; + } + catch + { + Assert.Fail("Wrong exception type thrown."); + } + + Assert.Fail("This MUST fail"); + } + } + async Task Authenticate(ILowLevelMqttClient client) { await client.SendAsync(new MqttConnectPacket diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs index 3fb399e..ca6988a 100644 --- a/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs @@ -6,7 +6,9 @@ using MQTTnet.Server; using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; +using MQTTnet.LowLevelClient; namespace MQTTnet.Tests.Mockups { @@ -100,6 +102,25 @@ namespace MQTTnet.Tests.Mockups return ConnectClientAsync(new MqttClientOptionsBuilder()); } + public Task ConnectLowLevelClientAsync() + { + return ConnectLowLevelClientAsync(o => {}); + } + + public async Task ConnectLowLevelClientAsync(Action optionsBuilder) + { + if (optionsBuilder == null) throw new ArgumentNullException(nameof(optionsBuilder)); + + var options = new MqttClientOptionsBuilder(); + options = options.WithTcpServer("127.0.0.1", ServerPort); + optionsBuilder.Invoke(options); + + var client = new MqttFactory().CreateLowLevelMqttClient(); + await client.ConnectAsync(options.Build(), CancellationToken.None).ConfigureAwait(false); + + return client; + } + public async Task ConnectClientAsync(MqttClientOptionsBuilder options) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -111,7 +132,7 @@ namespace MQTTnet.Tests.Mockups return client; } - + public async Task ConnectClientAsync(IMqttClientOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs index 07fff2a..833160f 100644 --- a/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Tests.Mockups { public class TestMqttCommunicationAdapter : IMqttChannelAdapter { - private readonly BlockingCollection _incomingPackets = new BlockingCollection(); + readonly BlockingCollection _incomingPackets = new BlockingCollection(); public TestMqttCommunicationAdapter Partner { get; set; } @@ -26,8 +26,7 @@ namespace MQTTnet.Tests.Mockups public long BytesSent { get; } public long BytesReceived { get; } - public Action ReadingPacketStartedCallback { get; set; } - public Action ReadingPacketCompletedCallback { get; set; } + public bool IsReadingPacket { get; } public void Dispose() { diff --git a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitor_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitor_Tests.cs deleted file mode 100644 index e275e97..0000000 --- a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitor_Tests.cs +++ /dev/null @@ -1,118 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Adapter; -using MQTTnet.Diagnostics; -using MQTTnet.Server; -using MQTTnet.Server.Status; -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.Tests -{ - [TestClass] - public class MqttKeepAliveMonitor_Tests - { - [TestMethod] - public async Task KeepAlive_Timeout() - { - var counter = 0; - - var monitor = new MqttClientKeepAliveMonitor("", () => - { - counter++; - return Task.CompletedTask; - }, - new MqttNetLogger()); - - Assert.AreEqual(0, counter); - - monitor.Start(1, CancellationToken.None); - - Assert.AreEqual(0, counter); - - await Task.Delay(2000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. - - Assert.AreEqual(1, counter); - } - - [TestMethod] - public async Task KeepAlive_NoTimeout() - { - var counter = 0; - - var monitor = new MqttClientKeepAliveMonitor("", () => - { - counter++; - return Task.CompletedTask; - }, - new MqttNetLogger()); - - Assert.AreEqual(0, counter); - - monitor.Start(1, CancellationToken.None); - - Assert.AreEqual(0, counter); - - // Simulate traffic. - await Task.Delay(1000); // Internally the keep alive timeout is multiplied with 1.5 as per protocol specification. - monitor.PacketReceived(); - await Task.Delay(1000); - monitor.PacketReceived(); - await Task.Delay(1000); - - Assert.AreEqual(0, counter); - - await Task.Delay(2000); - - Assert.AreEqual(1, counter); - } - - private class TestClientSession : IMqttClientSession - { - public string ClientId { get; } - - public int StopCalledCount { get; private set; } - - public void FillStatus(MqttClientStatus status) - { - throw new NotSupportedException(); - } - - public void EnqueueApplicationMessage(MqttClientConnection senderClientSession, MqttApplicationMessage applicationMessage) - { - throw new NotSupportedException(); - } - - public void ClearPendingApplicationMessages() - { - throw new NotSupportedException(); - } - - public Task RunAsync(MqttApplicationMessage willMessage, int keepAlivePeriod, IMqttChannelAdapter adapter) - { - throw new NotSupportedException(); - } - - public Task StopAsync() - { - StopCalledCount++; - return Task.FromResult(0); - } - - public Task SubscribeAsync(IList topicFilters) - { - throw new NotSupportedException(); - } - - public Task UnsubscribeAsync(IList topicFilters) - { - throw new NotSupportedException(); - } - - public void Dispose() - { - } - } - } -} diff --git a/Tests/MQTTnet.Core.Tests/Server_Connection_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Connection_Tests.cs new file mode 100644 index 0000000..ebd64d8 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Server_Connection_Tests.cs @@ -0,0 +1,78 @@ +using System; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; +using MQTTnet.Server; +using MQTTnet.Tests.Mockups; + +namespace MQTTnet.Tests +{ + [TestClass] + public sealed class Server_Connection_Tests + { + [TestMethod] + public async Task Close_Idle_Connection_On_Connect() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); + + var client = new CrossPlatformSocket(AddressFamily.InterNetwork); + await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); + + // Don't send anything. The server should close the connection. + await Task.Delay(TimeSpan.FromSeconds(3)); + + try + { + var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + if (receivedBytes == 0) + { + return; + } + + Assert.Fail("Receive should throw an exception."); + } + catch (SocketException) + { + } + } + } + + [TestMethod] + public async Task Send_Garbage() + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); + + // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state + // forever. This is security related. + var client = new CrossPlatformSocket(AddressFamily.InterNetwork); + await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); + + var buffer = Encoding.UTF8.GetBytes("Garbage"); + await client.SendAsync(new ArraySegment(buffer), SocketFlags.None); + + await Task.Delay(TimeSpan.FromSeconds(3)); + + try + { + var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + if (receivedBytes == 0) + { + return; + } + + Assert.Fail("Receive should throw an exception."); + } + catch (SocketException) + { + } + } + } + } +} \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/Server_KeepAlive_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_KeepAlive_Tests.cs new file mode 100644 index 0000000..58b41c8 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Server_KeepAlive_Tests.cs @@ -0,0 +1,58 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Packets; +using MQTTnet.Tests.Mockups; + +namespace MQTTnet.Tests +{ + [TestClass] + public sealed class Server_KeepAlive_Tests + { + [TestMethod] + public async Task Disconnect_Client_DueTo_KeepAlive() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + + var client = await testEnvironment.ConnectLowLevelClientAsync(o => o.WithCommunicationTimeout(TimeSpan.FromSeconds(2))).ConfigureAwait(false); + + await client.SendAsync(new MqttConnectPacket + { + CleanSession = true, + ClientId = "abc", + KeepAlivePeriod = 1, + }, CancellationToken.None).ConfigureAwait(false); + + var response = await client.ReceiveAsync(CancellationToken.None).ConfigureAwait(false); + + Assert.IsTrue(response is MqttConnAckPacket); + + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); + await Task.Delay(500); + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); + await Task.Delay(500); + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); + await Task.Delay(500); + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); + + // If we reach this point everything works as expected (server did not close the connection + // due to proper ping messages. + // Now we will wait 1.2 seconds because the server MUST wait 1.5 seconds in total (See spec). + + await Task.Delay(1200); + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); + + // Now we will wait longer than 1.5 so that the server will close the connection. + + await Task.Delay(3000); + + await server.StopAsync(); + + await client.ReceiveAsync(CancellationToken.None); + } + } + } +} \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs index 93c1d55..b93e3e4 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs @@ -6,7 +6,6 @@ using MQTTnet.Tests.Mockups; using MQTTnet.Client; using MQTTnet.Protocol; using MQTTnet.Server; -using System.Threading; namespace MQTTnet.Tests { @@ -63,7 +62,6 @@ namespace MQTTnet.Tests var clientStatus = await server.GetClientStatusAsync(); Assert.AreEqual(1, clientStatus.Count); - Assert.IsTrue(clientStatus.Any(s => s.ClientId == c1.Options.ClientId)); await clientStatus.First().DisconnectAsync(); diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index 16db377..0e70a32 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -6,14 +6,12 @@ using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; using MQTTnet.Client.Receiving; using MQTTnet.Client.Subscribing; -using MQTTnet.Implementations; using MQTTnet.Protocol; using MQTTnet.Server; using MQTTnet.Tests.Mockups; using System; using System.Collections.Generic; using System.Linq; -using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -21,7 +19,7 @@ using System.Threading.Tasks; namespace MQTTnet.Tests { [TestClass] - public class Server_Tests + public sealed class Server_Tests { public TestContext TestContext { get; set; } @@ -331,7 +329,7 @@ namespace MQTTnet.Tests await Task.Delay(10); } - + var c2 = await testEnvironment.ConnectClientAsync(); var messageBuilder = new MqttApplicationMessageBuilder(); @@ -1219,68 +1217,6 @@ namespace MQTTnet.Tests } } - [TestMethod] - public async Task Close_Idle_Connection() - { - using (var testEnvironment = new TestEnvironment(TestContext)) - { - await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); - - var client = new CrossPlatformSocket(AddressFamily.InterNetwork); - await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); - - // Don't send anything. The server should close the connection. - await Task.Delay(TimeSpan.FromSeconds(3)); - - try - { - var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); - if (receivedBytes == 0) - { - return; - } - - Assert.Fail("Receive should throw an exception."); - } - catch (SocketException) - { - } - } - } - - [TestMethod] - public async Task Send_Garbage() - { - using (var testEnvironment = new TestEnvironment(TestContext)) - { - await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); - - // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state - // forever. This is security related. - var client = new CrossPlatformSocket(AddressFamily.InterNetwork); - await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); - - var buffer = Encoding.UTF8.GetBytes("Garbage"); - await client.SendAsync(new ArraySegment(buffer), SocketFlags.None); - - await Task.Delay(TimeSpan.FromSeconds(3)); - - try - { - var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); - if (receivedBytes == 0) - { - return; - } - - Assert.Fail("Receive should throw an exception."); - } - catch (SocketException) - { - } - } - } - [TestMethod] public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() { @@ -1471,7 +1407,7 @@ namespace MQTTnet.Tests await client.SubscribeAsync(topic, MqttQualityOfServiceLevel.AtLeastOnce); - await client.PublishAsync(new MqttApplicationMessage{ Topic = topic, QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); + await client.PublishAsync(new MqttApplicationMessage { Topic = topic, QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); await Task.Delay(500);