diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index 70beba8..a163b81 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -3,6 +3,7 @@ using MQTTnet.Client; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Formatter; +using MQTTnet.Implementations; using MQTTnet.Internal; using MQTTnet.PacketDispatcher; using MQTTnet.Packets; @@ -17,30 +18,32 @@ namespace MQTTnet.Server { public sealed class MqttClientConnection : IDisposable { - private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); - private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); - private readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); + readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); + readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); + readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); - private readonly IMqttRetainedMessagesManager _retainedMessagesManager; - private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; - private readonly MqttClientSessionsManager _sessionsManager; + readonly IMqttRetainedMessagesManager _retainedMessagesManager; + readonly MqttClientKeepAliveMonitor _keepAliveMonitor; + readonly MqttClientSessionsManager _sessionsManager; - private readonly IMqttNetLogger _logger; - private readonly IMqttServerOptions _serverOptions; + readonly IMqttNetLogger _logger; + readonly IMqttServerOptions _serverOptions; - private readonly IMqttChannelAdapter _channelAdapter; - private readonly IMqttDataConverter _dataConverter; - private readonly string _endpoint; - private readonly DateTime _connectedTimestamp; + readonly IMqttChannelAdapter _channelAdapter; + readonly IMqttDataConverter _dataConverter; + readonly string _endpoint; + readonly DateTime _connectedTimestamp; - private Task _packageReceiverTask; - private DateTime _lastPacketReceivedTimestamp; - private DateTime _lastNonKeepAlivePacketReceivedTimestamp; + Task _packageReceiverTask; + DateTime _lastPacketReceivedTimestamp; + DateTime _lastNonKeepAlivePacketReceivedTimestamp; - private long _receivedPacketsCount; - private long _sentPacketsCount = 1; // Start with 1 because the CONNECT packet is not counted anywhere. - private long _receivedApplicationMessagesCount; - private long _sentApplicationMessagesCount; + long _receivedPacketsCount; + long _sentPacketsCount = 1; // Start with 1 because the CONNECT packet is not counted anywhere. + long _receivedApplicationMessagesCount; + long _sentApplicationMessagesCount; + + bool _isTakeover; public MqttClientConnection( MqttConnectPacket connectPacket, @@ -64,7 +67,7 @@ namespace MQTTnet.Server if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateChildLogger(nameof(MqttClientConnection)); - _keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, StopAsync, _logger); + _keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, () => StopAsync(), _logger); _connectedTimestamp = DateTime.UtcNow; _lastPacketReceivedTimestamp = _connectedTimestamp; @@ -77,15 +80,21 @@ namespace MQTTnet.Server public MqttClientSession Session { get; } - public async Task StopAsync() + public bool IsFinalized { get; set; } + + public Task StopAsync(bool isTakeover = false) { + _isTakeover = isTakeover; + StopInternal(); var task = _packageReceiverTask; if (task != null) { - await task.ConfigureAwait(false); + return task; } + + return PlatformAbstractionLayer.CompletedTask; } public void ResetStatistics() @@ -243,11 +252,16 @@ namespace MQTTnet.Server _channelAdapter.ReadingPacketStartedCallback = null; _channelAdapter.ReadingPacketCompletedCallback = null; - _logger.Info("Client '{0}': Session stopped.", ClientId); + _logger.Info("Client '{0}': Connection stopped.", ClientId); _packageReceiverTask = null; } + if (_isTakeover) + { + return MqttClientDisconnectType.Takeover; + } + return disconnectType; } @@ -319,7 +333,7 @@ namespace MQTTnet.Server _sessionsManager.DispatchApplicationMessage(applicationMessage, this); - return Task.FromResult(0); + return PlatformAbstractionLayer.CompletedTask; } Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) @@ -422,9 +436,6 @@ namespace MQTTnet.Server } _logger.Verbose("Queued application message sent (ClientId: {0}).", ClientId); - - // TODO: - //Interlocked.Increment(ref _sentPacketsCount); } } catch (Exception exception) diff --git a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs index 9d17552..6b28825 100644 --- a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs +++ b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs @@ -9,13 +9,13 @@ namespace MQTTnet.Server { public class MqttClientKeepAliveMonitor { - private readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch(); + readonly Stopwatch _lastPacketReceivedTracker = new Stopwatch(); - private readonly string _clientId; - private readonly Func _keepAliveElapsedCallback; - private readonly IMqttNetLogger _logger; + readonly string _clientId; + readonly Func _keepAliveElapsedCallback; + readonly IMqttNetLogger _logger; - private bool _isPaused; + bool _isPaused; public MqttClientKeepAliveMonitor(string clientId, Func keepAliveElapsedCallback, IMqttNetLogger logger) { @@ -51,7 +51,7 @@ namespace MQTTnet.Server _lastPacketReceivedTracker.Restart(); } - private async Task RunAsync(int keepAlivePeriod, CancellationToken cancellationToken) + async Task RunAsync(int keepAlivePeriod, CancellationToken cancellationToken) { try { diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 08568d8..1e2bcef 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -16,19 +16,19 @@ namespace MQTTnet.Server { public class MqttClientSessionsManager : Disposable { - private readonly AsyncQueue _messageQueue = new AsyncQueue(); + readonly AsyncQueue _messageQueue = new AsyncQueue(); - private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1); - private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); - private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); - private readonly IDictionary _serverSessionItems = new ConcurrentDictionary(); + readonly AsyncLock _createConnectionGate = new AsyncLock(); + readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); + readonly IDictionary _serverSessionItems = new ConcurrentDictionary(); - private readonly CancellationToken _cancellationToken; - private readonly MqttServerEventDispatcher _eventDispatcher; + readonly CancellationToken _cancellationToken; + readonly MqttServerEventDispatcher _eventDispatcher; - private readonly IMqttRetainedMessagesManager _retainedMessagesManager; - private readonly IMqttServerOptions _options; - private readonly IMqttNetLogger _logger; + readonly IMqttRetainedMessagesManager _retainedMessagesManager; + readonly IMqttServerOptions _options; + readonly IMqttNetLogger _logger; public MqttClientSessionsManager( IMqttServerOptions options, @@ -60,9 +60,11 @@ namespace MQTTnet.Server } } - public Task HandleClientAsync(IMqttChannelAdapter clientAdapter) + public Task HandleClientConnectionAsync(IMqttChannelAdapter clientAdapter) { - return HandleClientAsync(clientAdapter, _cancellationToken); + if (clientAdapter is null) throw new ArgumentNullException(nameof(clientAdapter)); + + return HandleClientConnectionAsync(clientAdapter, _cancellationToken); } public Task> GetClientStatusAsync() @@ -155,7 +157,7 @@ namespace MQTTnet.Server base.Dispose(disposing); } - private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken) + async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken) { while (!cancellationToken.IsCancellationRequested) { @@ -173,7 +175,7 @@ namespace MQTTnet.Server } } - private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken) + async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken) { try { @@ -231,14 +233,14 @@ namespace MQTTnet.Server } } - private async Task HandleClientAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) + async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { var disconnectType = MqttClientDisconnectType.NotClean; string clientId = null; - var clientWasConnected = true; + var clientWasAuthorized = false; MqttConnectPacket connectPacket; - + MqttClientConnection clientConnection = null; try { try @@ -261,7 +263,6 @@ namespace MQTTnet.Server if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { - clientWasConnected = false; // Send failure response here without preparing a session. The result for a successful connect // will be sent from the session itself. var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); @@ -270,12 +271,13 @@ namespace MQTTnet.Server return; } + clientWasAuthorized = true; clientId = connectPacket.ClientId; - var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); + clientConnection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); - disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); + disconnectType = await clientConnection.RunAsync(connectionValidatorContext).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -286,8 +288,10 @@ namespace MQTTnet.Server } finally { - if (clientWasConnected) + if (clientWasAuthorized && disconnectType != MqttClientDisconnectType.Takeover) { + // Only cleanup if the client was authorized. If not it will remove the existing connection, session etc. + // This allows to kill connections and sessions from known client IDs. if (clientId != null) { _connections.TryRemove(clientId, out _); @@ -297,18 +301,23 @@ namespace MQTTnet.Server await DeleteSessionAsync(clientId).ConfigureAwait(false); } } + } - await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); - if (clientId != null) - { - await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); - } + if (clientWasAuthorized && clientId != null) + { + await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + } + + if (clientConnection != null) + { + clientConnection.IsFinalized = true; } } } - private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) + async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) { var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary()); @@ -336,17 +345,22 @@ namespace MQTTnet.Server return context; } - private async Task CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) + async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) { - await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); - try + using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session); var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); if (isConnectionPresent) { - await existingConnection.StopAsync().ConfigureAwait(false); + await existingConnection.StopAsync(true); + + // TODO: This fixes a race condition with unit test Same_Client_Id_Connect_Disconnect_Event_Order. + // It is not clear where the issue is coming from. The connected event is fired BEFORE the disconnected + // event. This is wrong. It seems that the finally block in HandleClientAsync must be finished before we + // can continue here. Maybe there is a better way to do this. + SpinWait.SpinUntil(() => existingConnection.IsFinalized, TimeSpan.FromSeconds(10)); } if (isSessionPresent) @@ -376,13 +390,9 @@ namespace MQTTnet.Server return connection; } - finally - { - _createConnectionGate.Release(); - } } - private async Task InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage) + async Task InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage) { var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) @@ -410,7 +420,7 @@ namespace MQTTnet.Server return interceptorContext; } - private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter) + async Task SafeCleanupChannelAsync(IMqttChannelAdapter channelAdapter) { try { diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index 0fa6b4c..ea687eb 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -1,15 +1,15 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Adapter; +using MQTTnet.Adapter; using MQTTnet.Client.Publishing; using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Protocol; using MQTTnet.Server.Status; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet.Server { @@ -192,7 +192,7 @@ namespace MQTTnet.Server private Task OnHandleClient(IMqttChannelAdapter channelAdapter) { - return _clientSessionsManager.HandleClientAsync(channelAdapter); + return _clientSessionsManager.HandleClientConnectionAsync(channelAdapter); } } }