Pārlūkot izejas kodu

Refactor keep alive message handling in order to fix issues with sudden disconnects.

release/3.x.x
Christian pirms 3 gadiem
vecāks
revīzija
b766c1bebb
7 mainītis faili ar 123 papildinājumiem un 96 dzēšanām
  1. +1
    -0
      MQTTnet.sln.DotSettings
  2. +31
    -29
      Source/MQTTnet/Client/MqttClient.cs
  3. +2
    -0
      Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs
  4. +11
    -5
      Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs
  5. +11
    -0
      Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs
  6. +58
    -60
      Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs
  7. +9
    -2
      Source/MQTTnet/Server/MqttClientConnection.cs

+ 1
- 0
MQTTnet.sln.DotSettings Parādīt failu

@@ -12,4 +12,5 @@
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002ECSharpPlaceAttributeOnSameLineMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateBlankLinesAroundFieldToBlankLinesAroundProperty/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateThisQualifierSettings/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=PINGREQ/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unsub/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

+ 31
- 29
Source/MQTTnet/Client/MqttClient.cs Parādīt failu

@@ -14,7 +14,6 @@ using MQTTnet.PacketDispatcher;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Implementations;
@@ -25,8 +24,6 @@ namespace MQTTnet.Client
{
readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
readonly Stopwatch _sendTracker = new Stopwatch();
readonly Stopwatch _receiveTracker = new Stopwatch();
readonly object _disconnectLock = new object();

readonly IMqttClientAdapterFactory _adapterFactory;
@@ -44,6 +41,8 @@ namespace MQTTnet.Client
long _isDisconnectPending;
bool _isConnected;
MqttClientDisconnectReason _disconnectReason;
DateTime _lastPacketSentTimestamp;

public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger)
{
@@ -79,7 +78,7 @@ namespace MQTTnet.Client
Options = options;

_packetIdentifierProvider.Reset();
_packetDispatcher.Cancel();
_packetDispatcher.CancelAll();

_backgroundCancellationTokenSource = new CancellationTokenSource();
var backgroundCancellationToken = _backgroundCancellationTokenSource.Token;
@@ -102,8 +101,7 @@ namespace MQTTnet.Client
authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false);
}

_sendTracker.Restart();
_receiveTracker.Restart();
_lastPacketSentTimestamp = DateTime.UtcNow;

if (Options.KeepAlivePeriod != TimeSpan.Zero)
{
@@ -391,8 +389,8 @@ namespace MQTTnet.Client
{
cancellationToken.ThrowIfCancellationRequested();

_sendTracker.Restart();
_lastPacketSentTimestamp = DateTime.UtcNow;
return _adapter.SendPacketAsync(packet, cancellationToken);
}

@@ -400,18 +398,17 @@ namespace MQTTnet.Client
{
cancellationToken.ThrowIfCancellationRequested();

ushort identifier = 0;
if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0)
ushort packetIdentifier = 0;
if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier)
{
identifier = packetWithIdentifier.PacketIdentifier;
packetIdentifier = packetWithIdentifier.PacketIdentifier;
}

using (var packetAwaiter = _packetDispatcher.AddAwaiter<TResponsePacket>(identifier))
using (var packetAwaiter = _packetDispatcher.AddAwaiter<TResponsePacket>(packetIdentifier))
{
try
{
_sendTracker.Restart();
await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false);
await SendAsync(requestPacket, cancellationToken).ConfigureAwait(false);
}
catch (Exception exception)
{
@@ -446,15 +443,15 @@ namespace MQTTnet.Client
while (!cancellationToken.IsCancellationRequested)
{
// Values described here: [MQTT-3.1.2-24].
var waitTime = keepAlivePeriod - _sendTracker.Elapsed;
var timeWithoutPacketSent = DateTime.UtcNow - _lastPacketSentTimestamp;

if (waitTime <= TimeSpan.Zero)
if (timeWithoutPacketSent > keepAlivePeriod)
{
await SendAndReceiveAsync<MqttPingRespPacket>(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false);
await PingAsync(cancellationToken).ConfigureAwait(false);
}

// Wait a fixed time in all cases. Calculation of the remaining time is complicated
// due to some edge cases and was buggy in the past. Now we wait half a second because the
// due to some edge cases and was buggy in the past. Now we wait several ms because the
// min keep alive value is one second so that the server will wait 1.5 seconds for a PING
// packet.
await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken).ConfigureAwait(false);
@@ -538,7 +535,7 @@ namespace MQTTnet.Client
_logger.Error(exception, "Error while receiving packets.");
}

_packetDispatcher.Dispatch(exception);
_packetDispatcher.FailAll(exception);

if (!DisconnectIsPending())
{
@@ -555,8 +552,6 @@ namespace MQTTnet.Client
{
try
{
_receiveTracker.Restart();

if (packet is MqttPublishPacket publishPacket)
{
EnqueueReceivedPublishPacket(publishPacket);
@@ -569,10 +564,6 @@ namespace MQTTnet.Client
{
await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false);
}
else if (packet is MqttPingReqPacket)
{
await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false);
}
else if (packet is MqttDisconnectPacket disconnectPacket)
{
await ProcessReceivedDisconnectPacket(disconnectPacket).ConfigureAwait(false);
@@ -581,9 +572,20 @@ namespace MQTTnet.Client
{
await ProcessReceivedAuthPacket(authPacket).ConfigureAwait(false);
}
else if (packet is MqttPingRespPacket)
{
_packetDispatcher.TryDispatch(packet);
}
else if (packet is MqttPingReqPacket)
{
throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a Client to the Server only.");
}
else
{
_packetDispatcher.Dispatch(packet);
if (!_packetDispatcher.TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
}
}
}
catch (Exception exception)
@@ -605,7 +607,7 @@ namespace MQTTnet.Client
_logger.Error(exception, "Error while receiving packets.");
}

_packetDispatcher.Dispatch(exception);
_packetDispatcher.FailAll(exception);

if (!DisconnectIsPending())
{
@@ -703,8 +705,8 @@ namespace MQTTnet.Client
_disconnectReason = (MqttClientDisconnectReason)(disconnectPacket.ReasonCode ?? MqttDisconnectReasonCode.NormalDisconnection);

// Also dispatch disconnect to waiting threads to generate a proper exception.
_packetDispatcher.Dispatch(disconnectPacket);
_packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));
if (!DisconnectIsPending())
{
return DisconnectInternalAsync(_packetReceiverTask, null, null);


+ 2
- 0
Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs Parādīt failu

@@ -5,6 +5,8 @@ namespace MQTTnet.PacketDispatcher
{
public interface IMqttPacketAwaiter : IDisposable
{
MqttPacketAwaiterPacketFilter PacketFilter { get; }
void Complete(MqttBasePacket packet);

void Fail(Exception exception);


+ 11
- 5
Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs Parādīt failu

@@ -9,12 +9,16 @@ namespace MQTTnet.PacketDispatcher
public sealed class MqttPacketAwaiter<TPacket> : IMqttPacketAwaiter where TPacket : MqttBasePacket
{
readonly TaskCompletionSource<MqttBasePacket> _taskCompletionSource;
readonly ushort? _packetIdentifier;
readonly MqttPacketDispatcher _owningPacketDispatcher;

public MqttPacketAwaiter(ushort? packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
public MqttPacketAwaiter(ushort packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
{
_packetIdentifier = packetIdentifier;
PacketFilter = new MqttPacketAwaiterPacketFilter
{
Type = typeof(TPacket),
Identifier = packetIdentifier
};
_owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher));
#if NET452
_taskCompletionSource = new TaskCompletionSource<MqttBasePacket>();
@@ -22,7 +26,9 @@ namespace MQTTnet.PacketDispatcher
_taskCompletionSource = new TaskCompletionSource<MqttBasePacket>(TaskCreationOptions.RunContinuationsAsynchronously);
#endif
}

public MqttPacketAwaiterPacketFilter PacketFilter { get; }
public async Task<TPacket> WaitOneAsync(TimeSpan timeout)
{
using (var timeoutToken = new CancellationTokenSource(timeout))
@@ -82,7 +88,7 @@ namespace MQTTnet.PacketDispatcher

public void Dispose()
{
_owningPacketDispatcher.RemoveAwaiter<TPacket>(_packetIdentifier);
_owningPacketDispatcher.RemoveAwaiter(this);
}
}
}

+ 11
- 0
Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs Parādīt failu

@@ -0,0 +1,11 @@
using System;

namespace MQTTnet.PacketDispatcher
{
public sealed class MqttPacketAwaiterPacketFilter
{
public Type Type { get; set; }
public ushort Identifier { get; set; }
}
}

+ 58
- 60
Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs Parādīt failu

@@ -1,103 +1,101 @@
using MQTTnet.Exceptions;
using MQTTnet.Packets;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

namespace MQTTnet.PacketDispatcher
{
public sealed class MqttPacketDispatcher
{
readonly ConcurrentDictionary<Tuple<ushort, Type>, IMqttPacketAwaiter> _awaiters = new ConcurrentDictionary<Tuple<ushort, Type>, IMqttPacketAwaiter>();
readonly List<IMqttPacketAwaiter> _awaiters = new List<IMqttPacketAwaiter>();

public void Dispatch(Exception exception)
public void FailAll(Exception exception)
{
foreach (var awaiter in _awaiters)
if (exception == null) throw new ArgumentNullException(nameof(exception));

lock (_awaiters)
{
awaiter.Value.Fail(exception);
}
foreach (var awaiter in _awaiters)
{
awaiter.Fail(exception);
}

_awaiters.Clear();
_awaiters.Clear();
}
}

public bool TryDispatch(MqttBasePacket packet)
public void CancelAll()
{
if (packet == null) throw new ArgumentNullException(nameof(packet));

if (packet is MqttDisconnectPacket disconnectPacket)
lock (_awaiters)
{
foreach (var packetAwaiter in _awaiters)
foreach (var entry in _awaiters)
{
packetAwaiter.Value.Fail(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));
entry.Cancel();
}

return true;
_awaiters.Clear();
}

}
public bool TryDispatch(MqttBasePacket packet)
{
if (packet == null) throw new ArgumentNullException(nameof(packet));
ushort identifier = 0;
if (packet is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0)
if (packet is IMqttPacketWithIdentifier packetWithIdentifier)
{
identifier = packetWithIdentifier.PacketIdentifier;
}

var type = packet.GetType();
var key = new Tuple<ushort, Type>(identifier, type);
if (_awaiters.TryRemove(key, out var awaiter))
var packetType = packet.GetType();
var matchingAwaiters = new List<IMqttPacketAwaiter>();
lock (_awaiters)
{
awaiter.Complete(packet);
return true;
}

return false;
}

public void Dispatch(MqttBasePacket packet)
{
if (packet == null) throw new ArgumentNullException(nameof(packet));

if (!TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
for (var i = _awaiters.Count - 1; i >= 0; i--)
{
var entry = _awaiters[i];

// Note: The PingRespPacket will also arrive here and has NO identifier but there
// is code which waits for it. So the code must be able to deal with filters which
// are referring to the type only (identifier is 0)!
if (entry.PacketFilter.Type != packetType || entry.PacketFilter.Identifier != identifier)
{
continue;
}
matchingAwaiters.Add(entry);
_awaiters.RemoveAt(i);
}
}
}

public void Cancel()
{
foreach (var awaiter in _awaiters)
foreach (var matchingEntry in matchingAwaiters)
{
awaiter.Value.Cancel();
matchingEntry.Complete(packet);
}

_awaiters.Clear();
return matchingAwaiters.Count > 0;
}
public MqttPacketAwaiter<TResponsePacket> AddAwaiter<TResponsePacket>(ushort? identifier) where TResponsePacket : MqttBasePacket
public MqttPacketAwaiter<TResponsePacket> AddAwaiter<TResponsePacket>(ushort packetIdentifier) where TResponsePacket : MqttBasePacket
{
if (!identifier.HasValue)
{
identifier = 0;
}

var awaiter = new MqttPacketAwaiter<TResponsePacket>(identifier, this);
var awaiter = new MqttPacketAwaiter<TResponsePacket>(packetIdentifier, this);

var key = new Tuple<ushort, Type>(identifier.Value, typeof(TResponsePacket));
if (!_awaiters.TryAdd(key, awaiter))
lock (_awaiters)
{
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{key.Item2.Name}' with identifier {key.Item1}.");
_awaiters.Add(awaiter);
}
return awaiter;
}

public void RemoveAwaiter<TResponsePacket>(ushort? identifier) where TResponsePacket : MqttBasePacket
public void RemoveAwaiter(IMqttPacketAwaiter awaiter)
{
if (!identifier.HasValue)
if (awaiter == null) throw new ArgumentNullException(nameof(awaiter));
lock (_awaiters)
{
identifier = 0;
_awaiters.Remove(awaiter);
}

var key = new Tuple<ushort, Type>(identifier.Value, typeof(TResponsePacket));
_awaiters.TryRemove(key, out _);
}
}
}

+ 9
- 2
Source/MQTTnet/Server/MqttClientConnection.cs Parādīt failu

@@ -212,6 +212,10 @@ namespace MQTTnet.Server
{
await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false);
}
else if (packet is MqttPingRespPacket)
{
throw new MqttProtocolViolationException("A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ Packet only.");
}
else if (packet is MqttDisconnectPacket)
{
Session.WillMessage = null;
@@ -222,7 +226,10 @@ namespace MQTTnet.Server
}
else
{
_packetDispatcher.Dispatch(packet);
if (!_packetDispatcher.TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
}
}
}
}
@@ -255,7 +262,7 @@ namespace MQTTnet.Server
Session.WillMessage = null;
}

_packetDispatcher.Cancel();
_packetDispatcher.CancelAll();

_logger.Info("Client '{0}': Connection stopped.", ClientId);



Notiek ielāde…
Atcelt
Saglabāt