Browse Source

Refactor keep alive message handling in order to fix issues with sudden disconnects.

release/3.x.x
Christian 3 years ago
parent
commit
b766c1bebb
7 changed files with 123 additions and 96 deletions
  1. +1
    -0
      MQTTnet.sln.DotSettings
  2. +31
    -29
      Source/MQTTnet/Client/MqttClient.cs
  3. +2
    -0
      Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs
  4. +11
    -5
      Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs
  5. +11
    -0
      Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs
  6. +58
    -60
      Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs
  7. +9
    -2
      Source/MQTTnet/Server/MqttClientConnection.cs

+ 1
- 0
MQTTnet.sln.DotSettings View File

@@ -12,4 +12,5 @@
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002ECSharpPlaceAttributeOnSameLineMigration/@EntryIndexedValue">True</s:Boolean> <s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002ECSharpPlaceAttributeOnSameLineMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateBlankLinesAroundFieldToBlankLinesAroundProperty/@EntryIndexedValue">True</s:Boolean> <s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateBlankLinesAroundFieldToBlankLinesAroundProperty/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateThisQualifierSettings/@EntryIndexedValue">True</s:Boolean> <s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateThisQualifierSettings/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=PINGREQ/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unsub/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary> <s:Boolean x:Key="/Default/UserDictionary/Words/=unsub/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

+ 31
- 29
Source/MQTTnet/Client/MqttClient.cs View File

@@ -14,7 +14,6 @@ using MQTTnet.PacketDispatcher;
using MQTTnet.Packets; using MQTTnet.Packets;
using MQTTnet.Protocol; using MQTTnet.Protocol;
using System; using System;
using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Implementations; using MQTTnet.Implementations;
@@ -25,8 +24,6 @@ namespace MQTTnet.Client
{ {
readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
readonly Stopwatch _sendTracker = new Stopwatch();
readonly Stopwatch _receiveTracker = new Stopwatch();
readonly object _disconnectLock = new object(); readonly object _disconnectLock = new object();


readonly IMqttClientAdapterFactory _adapterFactory; readonly IMqttClientAdapterFactory _adapterFactory;
@@ -44,6 +41,8 @@ namespace MQTTnet.Client
long _isDisconnectPending; long _isDisconnectPending;
bool _isConnected; bool _isConnected;
MqttClientDisconnectReason _disconnectReason; MqttClientDisconnectReason _disconnectReason;
DateTime _lastPacketSentTimestamp;


public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger) public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger)
{ {
@@ -79,7 +78,7 @@ namespace MQTTnet.Client
Options = options; Options = options;


_packetIdentifierProvider.Reset(); _packetIdentifierProvider.Reset();
_packetDispatcher.Cancel();
_packetDispatcher.CancelAll();


_backgroundCancellationTokenSource = new CancellationTokenSource(); _backgroundCancellationTokenSource = new CancellationTokenSource();
var backgroundCancellationToken = _backgroundCancellationTokenSource.Token; var backgroundCancellationToken = _backgroundCancellationTokenSource.Token;
@@ -102,8 +101,7 @@ namespace MQTTnet.Client
authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false); authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false);
} }


_sendTracker.Restart();
_receiveTracker.Restart();
_lastPacketSentTimestamp = DateTime.UtcNow;


if (Options.KeepAlivePeriod != TimeSpan.Zero) if (Options.KeepAlivePeriod != TimeSpan.Zero)
{ {
@@ -391,8 +389,8 @@ namespace MQTTnet.Client
{ {
cancellationToken.ThrowIfCancellationRequested(); cancellationToken.ThrowIfCancellationRequested();


_sendTracker.Restart();
_lastPacketSentTimestamp = DateTime.UtcNow;
return _adapter.SendPacketAsync(packet, cancellationToken); return _adapter.SendPacketAsync(packet, cancellationToken);
} }


@@ -400,18 +398,17 @@ namespace MQTTnet.Client
{ {
cancellationToken.ThrowIfCancellationRequested(); cancellationToken.ThrowIfCancellationRequested();


ushort identifier = 0;
if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0)
ushort packetIdentifier = 0;
if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier)
{ {
identifier = packetWithIdentifier.PacketIdentifier;
packetIdentifier = packetWithIdentifier.PacketIdentifier;
} }


using (var packetAwaiter = _packetDispatcher.AddAwaiter<TResponsePacket>(identifier))
using (var packetAwaiter = _packetDispatcher.AddAwaiter<TResponsePacket>(packetIdentifier))
{ {
try try
{ {
_sendTracker.Restart();
await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false);
await SendAsync(requestPacket, cancellationToken).ConfigureAwait(false);
} }
catch (Exception exception) catch (Exception exception)
{ {
@@ -446,15 +443,15 @@ namespace MQTTnet.Client
while (!cancellationToken.IsCancellationRequested) while (!cancellationToken.IsCancellationRequested)
{ {
// Values described here: [MQTT-3.1.2-24]. // Values described here: [MQTT-3.1.2-24].
var waitTime = keepAlivePeriod - _sendTracker.Elapsed;
var timeWithoutPacketSent = DateTime.UtcNow - _lastPacketSentTimestamp;


if (waitTime <= TimeSpan.Zero)
if (timeWithoutPacketSent > keepAlivePeriod)
{ {
await SendAndReceiveAsync<MqttPingRespPacket>(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false);
await PingAsync(cancellationToken).ConfigureAwait(false);
} }


// Wait a fixed time in all cases. Calculation of the remaining time is complicated // Wait a fixed time in all cases. Calculation of the remaining time is complicated
// due to some edge cases and was buggy in the past. Now we wait half a second because the
// due to some edge cases and was buggy in the past. Now we wait several ms because the
// min keep alive value is one second so that the server will wait 1.5 seconds for a PING // min keep alive value is one second so that the server will wait 1.5 seconds for a PING
// packet. // packet.
await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken).ConfigureAwait(false); await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken).ConfigureAwait(false);
@@ -538,7 +535,7 @@ namespace MQTTnet.Client
_logger.Error(exception, "Error while receiving packets."); _logger.Error(exception, "Error while receiving packets.");
} }


_packetDispatcher.Dispatch(exception);
_packetDispatcher.FailAll(exception);


if (!DisconnectIsPending()) if (!DisconnectIsPending())
{ {
@@ -555,8 +552,6 @@ namespace MQTTnet.Client
{ {
try try
{ {
_receiveTracker.Restart();

if (packet is MqttPublishPacket publishPacket) if (packet is MqttPublishPacket publishPacket)
{ {
EnqueueReceivedPublishPacket(publishPacket); EnqueueReceivedPublishPacket(publishPacket);
@@ -569,10 +564,6 @@ namespace MQTTnet.Client
{ {
await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false); await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false);
} }
else if (packet is MqttPingReqPacket)
{
await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false);
}
else if (packet is MqttDisconnectPacket disconnectPacket) else if (packet is MqttDisconnectPacket disconnectPacket)
{ {
await ProcessReceivedDisconnectPacket(disconnectPacket).ConfigureAwait(false); await ProcessReceivedDisconnectPacket(disconnectPacket).ConfigureAwait(false);
@@ -581,9 +572,20 @@ namespace MQTTnet.Client
{ {
await ProcessReceivedAuthPacket(authPacket).ConfigureAwait(false); await ProcessReceivedAuthPacket(authPacket).ConfigureAwait(false);
} }
else if (packet is MqttPingRespPacket)
{
_packetDispatcher.TryDispatch(packet);
}
else if (packet is MqttPingReqPacket)
{
throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a Client to the Server only.");
}
else else
{ {
_packetDispatcher.Dispatch(packet);
if (!_packetDispatcher.TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
}
} }
} }
catch (Exception exception) catch (Exception exception)
@@ -605,7 +607,7 @@ namespace MQTTnet.Client
_logger.Error(exception, "Error while receiving packets."); _logger.Error(exception, "Error while receiving packets.");
} }


_packetDispatcher.Dispatch(exception);
_packetDispatcher.FailAll(exception);


if (!DisconnectIsPending()) if (!DisconnectIsPending())
{ {
@@ -703,8 +705,8 @@ namespace MQTTnet.Client
_disconnectReason = (MqttClientDisconnectReason)(disconnectPacket.ReasonCode ?? MqttDisconnectReasonCode.NormalDisconnection); _disconnectReason = (MqttClientDisconnectReason)(disconnectPacket.ReasonCode ?? MqttDisconnectReasonCode.NormalDisconnection);


// Also dispatch disconnect to waiting threads to generate a proper exception. // Also dispatch disconnect to waiting threads to generate a proper exception.
_packetDispatcher.Dispatch(disconnectPacket);
_packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));
if (!DisconnectIsPending()) if (!DisconnectIsPending())
{ {
return DisconnectInternalAsync(_packetReceiverTask, null, null); return DisconnectInternalAsync(_packetReceiverTask, null, null);


+ 2
- 0
Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs View File

@@ -5,6 +5,8 @@ namespace MQTTnet.PacketDispatcher
{ {
public interface IMqttPacketAwaiter : IDisposable public interface IMqttPacketAwaiter : IDisposable
{ {
MqttPacketAwaiterPacketFilter PacketFilter { get; }
void Complete(MqttBasePacket packet); void Complete(MqttBasePacket packet);


void Fail(Exception exception); void Fail(Exception exception);


+ 11
- 5
Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs View File

@@ -9,12 +9,16 @@ namespace MQTTnet.PacketDispatcher
public sealed class MqttPacketAwaiter<TPacket> : IMqttPacketAwaiter where TPacket : MqttBasePacket public sealed class MqttPacketAwaiter<TPacket> : IMqttPacketAwaiter where TPacket : MqttBasePacket
{ {
readonly TaskCompletionSource<MqttBasePacket> _taskCompletionSource; readonly TaskCompletionSource<MqttBasePacket> _taskCompletionSource;
readonly ushort? _packetIdentifier;
readonly MqttPacketDispatcher _owningPacketDispatcher; readonly MqttPacketDispatcher _owningPacketDispatcher;


public MqttPacketAwaiter(ushort? packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
public MqttPacketAwaiter(ushort packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
{ {
_packetIdentifier = packetIdentifier;
PacketFilter = new MqttPacketAwaiterPacketFilter
{
Type = typeof(TPacket),
Identifier = packetIdentifier
};
_owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher)); _owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher));
#if NET452 #if NET452
_taskCompletionSource = new TaskCompletionSource<MqttBasePacket>(); _taskCompletionSource = new TaskCompletionSource<MqttBasePacket>();
@@ -22,7 +26,9 @@ namespace MQTTnet.PacketDispatcher
_taskCompletionSource = new TaskCompletionSource<MqttBasePacket>(TaskCreationOptions.RunContinuationsAsynchronously); _taskCompletionSource = new TaskCompletionSource<MqttBasePacket>(TaskCreationOptions.RunContinuationsAsynchronously);
#endif #endif
} }

public MqttPacketAwaiterPacketFilter PacketFilter { get; }
public async Task<TPacket> WaitOneAsync(TimeSpan timeout) public async Task<TPacket> WaitOneAsync(TimeSpan timeout)
{ {
using (var timeoutToken = new CancellationTokenSource(timeout)) using (var timeoutToken = new CancellationTokenSource(timeout))
@@ -82,7 +88,7 @@ namespace MQTTnet.PacketDispatcher


public void Dispose() public void Dispose()
{ {
_owningPacketDispatcher.RemoveAwaiter<TPacket>(_packetIdentifier);
_owningPacketDispatcher.RemoveAwaiter(this);
} }
} }
} }

+ 11
- 0
Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs View File

@@ -0,0 +1,11 @@
using System;

namespace MQTTnet.PacketDispatcher
{
public sealed class MqttPacketAwaiterPacketFilter
{
public Type Type { get; set; }
public ushort Identifier { get; set; }
}
}

+ 58
- 60
Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs View File

@@ -1,103 +1,101 @@
using MQTTnet.Exceptions;
using MQTTnet.Packets; using MQTTnet.Packets;
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic;


namespace MQTTnet.PacketDispatcher namespace MQTTnet.PacketDispatcher
{ {
public sealed class MqttPacketDispatcher public sealed class MqttPacketDispatcher
{ {
readonly ConcurrentDictionary<Tuple<ushort, Type>, IMqttPacketAwaiter> _awaiters = new ConcurrentDictionary<Tuple<ushort, Type>, IMqttPacketAwaiter>();
readonly List<IMqttPacketAwaiter> _awaiters = new List<IMqttPacketAwaiter>();


public void Dispatch(Exception exception)
public void FailAll(Exception exception)
{ {
foreach (var awaiter in _awaiters)
if (exception == null) throw new ArgumentNullException(nameof(exception));

lock (_awaiters)
{ {
awaiter.Value.Fail(exception);
}
foreach (var awaiter in _awaiters)
{
awaiter.Fail(exception);
}


_awaiters.Clear();
_awaiters.Clear();
}
} }


public bool TryDispatch(MqttBasePacket packet)
public void CancelAll()
{ {
if (packet == null) throw new ArgumentNullException(nameof(packet));

if (packet is MqttDisconnectPacket disconnectPacket)
lock (_awaiters)
{ {
foreach (var packetAwaiter in _awaiters)
foreach (var entry in _awaiters)
{ {
packetAwaiter.Value.Fail(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));
entry.Cancel();
} }


return true;
_awaiters.Clear();
} }

}
public bool TryDispatch(MqttBasePacket packet)
{
if (packet == null) throw new ArgumentNullException(nameof(packet));
ushort identifier = 0; ushort identifier = 0;
if (packet is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0)
if (packet is IMqttPacketWithIdentifier packetWithIdentifier)
{ {
identifier = packetWithIdentifier.PacketIdentifier; identifier = packetWithIdentifier.PacketIdentifier;
} }


var type = packet.GetType();
var key = new Tuple<ushort, Type>(identifier, type);
if (_awaiters.TryRemove(key, out var awaiter))
var packetType = packet.GetType();
var matchingAwaiters = new List<IMqttPacketAwaiter>();
lock (_awaiters)
{ {
awaiter.Complete(packet);
return true;
}

return false;
}

public void Dispatch(MqttBasePacket packet)
{
if (packet == null) throw new ArgumentNullException(nameof(packet));

if (!TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
for (var i = _awaiters.Count - 1; i >= 0; i--)
{
var entry = _awaiters[i];

// Note: The PingRespPacket will also arrive here and has NO identifier but there
// is code which waits for it. So the code must be able to deal with filters which
// are referring to the type only (identifier is 0)!
if (entry.PacketFilter.Type != packetType || entry.PacketFilter.Identifier != identifier)
{
continue;
}
matchingAwaiters.Add(entry);
_awaiters.RemoveAt(i);
}
} }
}

public void Cancel()
{
foreach (var awaiter in _awaiters)
foreach (var matchingEntry in matchingAwaiters)
{ {
awaiter.Value.Cancel();
matchingEntry.Complete(packet);
} }


_awaiters.Clear();
return matchingAwaiters.Count > 0;
} }
public MqttPacketAwaiter<TResponsePacket> AddAwaiter<TResponsePacket>(ushort? identifier) where TResponsePacket : MqttBasePacket
public MqttPacketAwaiter<TResponsePacket> AddAwaiter<TResponsePacket>(ushort packetIdentifier) where TResponsePacket : MqttBasePacket
{ {
if (!identifier.HasValue)
{
identifier = 0;
}

var awaiter = new MqttPacketAwaiter<TResponsePacket>(identifier, this);
var awaiter = new MqttPacketAwaiter<TResponsePacket>(packetIdentifier, this);


var key = new Tuple<ushort, Type>(identifier.Value, typeof(TResponsePacket));
if (!_awaiters.TryAdd(key, awaiter))
lock (_awaiters)
{ {
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{key.Item2.Name}' with identifier {key.Item1}.");
_awaiters.Add(awaiter);
} }
return awaiter; return awaiter;
} }


public void RemoveAwaiter<TResponsePacket>(ushort? identifier) where TResponsePacket : MqttBasePacket
public void RemoveAwaiter(IMqttPacketAwaiter awaiter)
{ {
if (!identifier.HasValue)
if (awaiter == null) throw new ArgumentNullException(nameof(awaiter));
lock (_awaiters)
{ {
identifier = 0;
_awaiters.Remove(awaiter);
} }

var key = new Tuple<ushort, Type>(identifier.Value, typeof(TResponsePacket));
_awaiters.TryRemove(key, out _);
} }
} }
} }

+ 9
- 2
Source/MQTTnet/Server/MqttClientConnection.cs View File

@@ -212,6 +212,10 @@ namespace MQTTnet.Server
{ {
await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false); await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false);
} }
else if (packet is MqttPingRespPacket)
{
throw new MqttProtocolViolationException("A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ Packet only.");
}
else if (packet is MqttDisconnectPacket) else if (packet is MqttDisconnectPacket)
{ {
Session.WillMessage = null; Session.WillMessage = null;
@@ -222,7 +226,10 @@ namespace MQTTnet.Server
} }
else else
{ {
_packetDispatcher.Dispatch(packet);
if (!_packetDispatcher.TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
}
} }
} }
} }
@@ -255,7 +262,7 @@ namespace MQTTnet.Server
Session.WillMessage = null; Session.WillMessage = null;
} }


_packetDispatcher.Cancel();
_packetDispatcher.CancelAll();


_logger.Info("Client '{0}': Connection stopped.", ClientId); _logger.Info("Client '{0}': Connection stopped.", ClientId);




Loading…
Cancel
Save