Browse Source

Fix connect and disconnect event flow.

release/3.x.x
Christian Kratky 4 years ago
parent
commit
d464bfb7c5
4 changed files with 98 additions and 77 deletions
  1. +38
    -27
      Source/MQTTnet/Server/MqttClientConnection.cs
  2. +6
    -6
      Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs
  3. +47
    -37
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  4. +7
    -7
      Source/MQTTnet/Server/MqttServer.cs

+ 38
- 27
Source/MQTTnet/Server/MqttClientConnection.cs View File

@@ -3,6 +3,7 @@ using MQTTnet.Client;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Exceptions; using MQTTnet.Exceptions;
using MQTTnet.Formatter; using MQTTnet.Formatter;
using MQTTnet.Implementations;
using MQTTnet.Internal; using MQTTnet.Internal;
using MQTTnet.PacketDispatcher; using MQTTnet.PacketDispatcher;
using MQTTnet.Packets; using MQTTnet.Packets;
@@ -17,30 +18,32 @@ namespace MQTTnet.Server
{ {
public sealed class MqttClientConnection : IDisposable public sealed class MqttClientConnection : IDisposable
{ {
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
private readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource();
readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource();


private readonly IMqttRetainedMessagesManager _retainedMessagesManager;
private readonly MqttClientKeepAliveMonitor _keepAliveMonitor;
private readonly MqttClientSessionsManager _sessionsManager;
readonly IMqttRetainedMessagesManager _retainedMessagesManager;
readonly MqttClientKeepAliveMonitor _keepAliveMonitor;
readonly MqttClientSessionsManager _sessionsManager;


private readonly IMqttNetLogger _logger;
private readonly IMqttServerOptions _serverOptions;
readonly IMqttNetLogger _logger;
readonly IMqttServerOptions _serverOptions;


private readonly IMqttChannelAdapter _channelAdapter;
private readonly IMqttDataConverter _dataConverter;
private readonly string _endpoint;
private readonly DateTime _connectedTimestamp;
readonly IMqttChannelAdapter _channelAdapter;
readonly IMqttDataConverter _dataConverter;
readonly string _endpoint;
readonly DateTime _connectedTimestamp;


private Task<MqttClientDisconnectType> _packageReceiverTask;
private DateTime _lastPacketReceivedTimestamp;
private DateTime _lastNonKeepAlivePacketReceivedTimestamp;
Task<MqttClientDisconnectType> _packageReceiverTask;
DateTime _lastPacketReceivedTimestamp;
DateTime _lastNonKeepAlivePacketReceivedTimestamp;


private long _receivedPacketsCount;
private long _sentPacketsCount = 1; // Start with 1 because the CONNECT packet is not counted anywhere.
private long _receivedApplicationMessagesCount;
private long _sentApplicationMessagesCount;
long _receivedPacketsCount;
long _sentPacketsCount = 1; // Start with 1 because the CONNECT packet is not counted anywhere.
long _receivedApplicationMessagesCount;
long _sentApplicationMessagesCount;

bool _isTakeover;


public MqttClientConnection( public MqttClientConnection(
MqttConnectPacket connectPacket, MqttConnectPacket connectPacket,
@@ -64,7 +67,7 @@ namespace MQTTnet.Server
if (logger == null) throw new ArgumentNullException(nameof(logger)); if (logger == null) throw new ArgumentNullException(nameof(logger));
_logger = logger.CreateChildLogger(nameof(MqttClientConnection)); _logger = logger.CreateChildLogger(nameof(MqttClientConnection));


_keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, StopAsync, _logger);
_keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, () => StopAsync(), _logger);


_connectedTimestamp = DateTime.UtcNow; _connectedTimestamp = DateTime.UtcNow;
_lastPacketReceivedTimestamp = _connectedTimestamp; _lastPacketReceivedTimestamp = _connectedTimestamp;
@@ -77,15 +80,21 @@ namespace MQTTnet.Server


public MqttClientSession Session { get; } public MqttClientSession Session { get; }


public async Task StopAsync()
public bool IsFinalized { get; set; }

public Task StopAsync(bool isTakeover = false)
{ {
_isTakeover = isTakeover;

StopInternal(); StopInternal();


var task = _packageReceiverTask; var task = _packageReceiverTask;
if (task != null) if (task != null)
{ {
await task.ConfigureAwait(false);
return task;
} }

return PlatformAbstractionLayer.CompletedTask;
} }


public void ResetStatistics() public void ResetStatistics()
@@ -243,11 +252,16 @@ namespace MQTTnet.Server
_channelAdapter.ReadingPacketStartedCallback = null; _channelAdapter.ReadingPacketStartedCallback = null;
_channelAdapter.ReadingPacketCompletedCallback = null; _channelAdapter.ReadingPacketCompletedCallback = null;


_logger.Info("Client '{0}': Session stopped.", ClientId);
_logger.Info("Client '{0}': Connection stopped.", ClientId);


_packageReceiverTask = null; _packageReceiverTask = null;
} }


if (_isTakeover)
{
return MqttClientDisconnectType.Takeover;
}

return disconnectType; return disconnectType;
} }


@@ -319,7 +333,7 @@ namespace MQTTnet.Server


_sessionsManager.DispatchApplicationMessage(applicationMessage, this); _sessionsManager.DispatchApplicationMessage(applicationMessage, this);


return Task.FromResult(0);
return PlatformAbstractionLayer.CompletedTask;
} }


Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket)
@@ -422,9 +436,6 @@ namespace MQTTnet.Server
} }


_logger.Verbose("Queued application message sent (ClientId: {0}).", ClientId); _logger.Verbose("Queued application message sent (ClientId: {0}).", ClientId);

// TODO:
//Interlocked.Increment(ref _sentPacketsCount);
} }
} }
catch (Exception exception) catch (Exception exception)


+ 6
- 6
Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs View File

@@ -9,13 +9,13 @@ namespace MQTTnet.Server
{ {
public class MqttClientKeepAliveMonitor public class MqttClientKeepAliveMonitor
{ {
private readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch();
readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch();


private readonly string _clientId;
private readonly Func<Task> _keepAliveElapsedCallback;
private readonly IMqttNetLogger _logger;
readonly string _clientId;
readonly Func<Task> _keepAliveElapsedCallback;
readonly IMqttNetLogger _logger;


private bool _isPaused;
bool _isPaused;


public MqttClientKeepAliveMonitor(string clientId, Func<Task> keepAliveElapsedCallback, IMqttNetLogger logger) public MqttClientKeepAliveMonitor(string clientId, Func<Task> keepAliveElapsedCallback, IMqttNetLogger logger)
{ {
@@ -51,7 +51,7 @@ namespace MQTTnet.Server
_lastPacketReceivedTracker.Restart(); _lastPacketReceivedTracker.Restart();
} }


private async Task RunAsync(int keepAlivePeriod, CancellationToken cancellationToken)
async Task RunAsync(int keepAlivePeriod, CancellationToken cancellationToken)
{ {
try try
{ {


+ 47
- 37
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -16,19 +16,19 @@ namespace MQTTnet.Server
{ {
public class MqttClientSessionsManager : Disposable public class MqttClientSessionsManager : Disposable
{ {
private readonly AsyncQueue<MqttEnqueuedApplicationMessage> _messageQueue = new AsyncQueue<MqttEnqueuedApplicationMessage>();
readonly AsyncQueue<MqttEnqueuedApplicationMessage> _messageQueue = new AsyncQueue<MqttEnqueuedApplicationMessage>();


private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1);
private readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>();
private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();
private readonly IDictionary<object, object> _serverSessionItems = new ConcurrentDictionary<object, object>();
readonly AsyncLock _createConnectionGate = new AsyncLock();
readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>();
readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();
readonly IDictionary<object, object> _serverSessionItems = new ConcurrentDictionary<object, object>();


private readonly CancellationToken _cancellationToken;
private readonly MqttServerEventDispatcher _eventDispatcher;
readonly CancellationToken _cancellationToken;
readonly MqttServerEventDispatcher _eventDispatcher;


private readonly IMqttRetainedMessagesManager _retainedMessagesManager;
private readonly IMqttServerOptions _options;
private readonly IMqttNetLogger _logger;
readonly IMqttRetainedMessagesManager _retainedMessagesManager;
readonly IMqttServerOptions _options;
readonly IMqttNetLogger _logger;


public MqttClientSessionsManager( public MqttClientSessionsManager(
IMqttServerOptions options, IMqttServerOptions options,
@@ -60,9 +60,11 @@ namespace MQTTnet.Server
} }
} }


public Task HandleClientAsync(IMqttChannelAdapter clientAdapter)
public Task HandleClientConnectionAsync(IMqttChannelAdapter clientAdapter)
{ {
return HandleClientAsync(clientAdapter, _cancellationToken);
if (clientAdapter is null) throw new ArgumentNullException(nameof(clientAdapter));

return HandleClientConnectionAsync(clientAdapter, _cancellationToken);
} }


public Task<IList<IMqttClientStatus>> GetClientStatusAsync() public Task<IList<IMqttClientStatus>> GetClientStatusAsync()
@@ -155,7 +157,7 @@ namespace MQTTnet.Server
base.Dispose(disposing); base.Dispose(disposing);
} }


private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken)
async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken)
{ {
while (!cancellationToken.IsCancellationRequested) while (!cancellationToken.IsCancellationRequested)
{ {
@@ -173,7 +175,7 @@ namespace MQTTnet.Server
} }
} }


private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken)
async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken)
{ {
try try
{ {
@@ -231,14 +233,14 @@ namespace MQTTnet.Server
} }
} }


private async Task HandleClientAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{ {
var disconnectType = MqttClientDisconnectType.NotClean; var disconnectType = MqttClientDisconnectType.NotClean;
string clientId = null; string clientId = null;
var clientWasConnected = true;
var clientWasAuthorized = false;


MqttConnectPacket connectPacket; MqttConnectPacket connectPacket;
MqttClientConnection clientConnection = null;
try try
{ {
try try
@@ -261,7 +263,6 @@ namespace MQTTnet.Server


if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success)
{ {
clientWasConnected = false;
// Send failure response here without preparing a session. The result for a successful connect // Send failure response here without preparing a session. The result for a successful connect
// will be sent from the session itself. // will be sent from the session itself.
var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext);
@@ -270,12 +271,13 @@ namespace MQTTnet.Server
return; return;
} }


clientWasAuthorized = true;
clientId = connectPacket.ClientId; clientId = connectPacket.ClientId;
var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);
clientConnection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);


await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false);


disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false);
disconnectType = await clientConnection.RunAsync(connectionValidatorContext).ConfigureAwait(false);
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
@@ -286,8 +288,10 @@ namespace MQTTnet.Server
} }
finally finally
{ {
if (clientWasConnected)
if (clientWasAuthorized && disconnectType != MqttClientDisconnectType.Takeover)
{ {
// Only cleanup if the client was authorized. If not it will remove the existing connection, session etc.
// This allows to kill connections and sessions from known client IDs.
if (clientId != null) if (clientId != null)
{ {
_connections.TryRemove(clientId, out _); _connections.TryRemove(clientId, out _);
@@ -297,18 +301,23 @@ namespace MQTTnet.Server
await DeleteSessionAsync(clientId).ConfigureAwait(false); await DeleteSessionAsync(clientId).ConfigureAwait(false);
} }
} }
}


await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false);
await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false);


if (clientId != null)
{
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
}
if (clientWasAuthorized && clientId != null)
{
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
}

if (clientConnection != null)
{
clientConnection.IsFinalized = true;
} }
} }
} }


private async Task<MqttConnectionValidatorContext> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
async Task<MqttConnectionValidatorContext> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
{ {
var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary<object, object>()); var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary<object, object>());


@@ -336,17 +345,22 @@ namespace MQTTnet.Server
return context; return context;
} }


private async Task<MqttClientConnection> CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
async Task<MqttClientConnection> CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
{ {
await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false);
try
using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false))
{ {
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session); var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session);


var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection);
if (isConnectionPresent) if (isConnectionPresent)
{ {
await existingConnection.StopAsync().ConfigureAwait(false);
await existingConnection.StopAsync(true);

// TODO: This fixes a race condition with unit test Same_Client_Id_Connect_Disconnect_Event_Order.
// It is not clear where the issue is coming from. The connected event is fired BEFORE the disconnected
// event. This is wrong. It seems that the finally block in HandleClientAsync must be finished before we
// can continue here. Maybe there is a better way to do this.
SpinWait.SpinUntil(() => existingConnection.IsFinalized, TimeSpan.FromSeconds(10));
} }


if (isSessionPresent) if (isSessionPresent)
@@ -376,13 +390,9 @@ namespace MQTTnet.Server


return connection; return connection;
} }
finally
{
_createConnectionGate.Release();
}
} }


private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage)
async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage)
{ {
var interceptor = _options.ApplicationMessageInterceptor; var interceptor = _options.ApplicationMessageInterceptor;
if (interceptor == null) if (interceptor == null)
@@ -410,7 +420,7 @@ namespace MQTTnet.Server
return interceptorContext; return interceptorContext;
} }


private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter)
async Task SafeCleanupChannelAsync(IMqttChannelAdapter channelAdapter)
{ {
try try
{ {


+ 7
- 7
Source/MQTTnet/Server/MqttServer.cs View File

@@ -1,15 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Adapter;
using MQTTnet.Client.Publishing; using MQTTnet.Client.Publishing;
using MQTTnet.Client.Receiving; using MQTTnet.Client.Receiving;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Exceptions; using MQTTnet.Exceptions;
using MQTTnet.Protocol; using MQTTnet.Protocol;
using MQTTnet.Server.Status; using MQTTnet.Server.Status;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
@@ -192,7 +192,7 @@ namespace MQTTnet.Server


private Task OnHandleClient(IMqttChannelAdapter channelAdapter) private Task OnHandleClient(IMqttChannelAdapter channelAdapter)
{ {
return _clientSessionsManager.HandleClientAsync(channelAdapter);
return _clientSessionsManager.HandleClientConnectionAsync(channelAdapter);
} }
} }
} }

Loading…
Cancel
Save