diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index a163b81..d22cbe9 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -23,6 +23,8 @@ namespace MQTTnet.Server readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); readonly IMqttRetainedMessagesManager _retainedMessagesManager; + readonly Func _onStart; + readonly Func _onStop; readonly MqttClientKeepAliveMonitor _keepAliveMonitor; readonly MqttClientSessionsManager _sessionsManager; @@ -34,7 +36,7 @@ namespace MQTTnet.Server readonly string _endpoint; readonly DateTime _connectedTimestamp; - Task _packageReceiverTask; + volatile Task _packageReceiverTask; DateTime _lastPacketReceivedTimestamp; DateTime _lastNonKeepAlivePacketReceivedTimestamp; @@ -43,7 +45,7 @@ namespace MQTTnet.Server long _receivedApplicationMessagesCount; long _sentApplicationMessagesCount; - bool _isTakeover; + volatile bool _isTakeover; public MqttClientConnection( MqttConnectPacket connectPacket, @@ -52,12 +54,16 @@ namespace MQTTnet.Server IMqttServerOptions serverOptions, MqttClientSessionsManager sessionsManager, IMqttRetainedMessagesManager retainedMessagesManager, + Func onStart, + Func onStop, IMqttNetLogger logger) { Session = session ?? throw new ArgumentNullException(nameof(session)); _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); + _onStart = onStart ?? throw new ArgumentNullException(nameof(onStart)); + _onStop = onStop ?? throw new ArgumentNullException(nameof(onStop)); _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; @@ -80,15 +86,13 @@ namespace MQTTnet.Server public MqttClientSession Session { get; } - public bool IsFinalized { get; set; } - public Task StopAsync(bool isTakeover = false) { _isTakeover = isTakeover; + var task = _packageReceiverTask; StopInternal(); - var task = _packageReceiverTask; if (task != null) { return task; @@ -127,17 +131,18 @@ namespace MQTTnet.Server _cancellationToken.Dispose(); } - public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) + public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) { _packageReceiverTask = RunInternalAsync(connectionValidatorContext); return _packageReceiverTask; } - async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) + async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) { var disconnectType = MqttClientDisconnectType.NotClean; try { + await _onStart(); _logger.Info("Client '{0}': Session started.", ClientId); _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; @@ -241,6 +246,11 @@ namespace MQTTnet.Server } finally { + if (_isTakeover) + { + disconnectType = MqttClientDisconnectType.Takeover; + } + if (Session.WillMessage != null) { _sessionsManager.DispatchApplicationMessage(Session.WillMessage, this); @@ -255,14 +265,16 @@ namespace MQTTnet.Server _logger.Info("Client '{0}': Connection stopped.", ClientId); _packageReceiverTask = null; - } - if (_isTakeover) - { - return MqttClientDisconnectType.Takeover; + try + { + await _onStop(disconnectType); + } + catch (Exception e) + { + _logger.Error(e, "client '{0}': Error while cleaning up", ClientId); + } } - - return disconnectType; } void StopInternal() diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 1e2bcef..f98eb05 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -185,6 +185,12 @@ namespace MQTTnet.Server } var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); + + if (!dequeueResult.IsSuccess) + { + return; + } + var queuedApplicationMessage = dequeueResult.Item; var sender = queuedApplicationMessage.Sender; @@ -235,12 +241,9 @@ namespace MQTTnet.Server async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { - var disconnectType = MqttClientDisconnectType.NotClean; string clientId = null; - var clientWasAuthorized = false; MqttConnectPacket connectPacket; - MqttClientConnection clientConnection = null; try { try @@ -271,13 +274,17 @@ namespace MQTTnet.Server return; } - clientWasAuthorized = true; clientId = connectPacket.ClientId; - clientConnection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); - await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); + var connection = await CreateClientConnectionAsync( + connectPacket, + connectionValidatorContext, + channelAdapter, + async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), + async disconnectType => await CleanUpClient(clientId, channelAdapter, disconnectType) + ).ConfigureAwait(false); - disconnectType = await clientConnection.RunAsync(connectionValidatorContext).ConfigureAwait(false); + await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -286,34 +293,25 @@ namespace MQTTnet.Server { _logger.Error(exception, exception.Message); } - finally - { - if (clientWasAuthorized && disconnectType != MqttClientDisconnectType.Takeover) - { - // Only cleanup if the client was authorized. If not it will remove the existing connection, session etc. - // This allows to kill connections and sessions from known client IDs. - if (clientId != null) - { - _connections.TryRemove(clientId, out _); - - if (!_options.EnablePersistentSessions) - { - await DeleteSessionAsync(clientId).ConfigureAwait(false); - } - } - } + } - await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + private async Task CleanUpClient(string clientId, IMqttChannelAdapter channelAdapter, MqttClientDisconnectType disconnectType) + { + if (clientId != null) + { + _connections.TryRemove(clientId, out _); - if (clientWasAuthorized && clientId != null) + if (!_options.EnablePersistentSessions) { - await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + await DeleteSessionAsync(clientId).ConfigureAwait(false); } + } - if (clientConnection != null) - { - clientConnection.IsFinalized = true; - } + await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + + if (clientId != null) + { + await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); } } @@ -345,7 +343,7 @@ namespace MQTTnet.Server return context; } - async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) + async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter, Func onStart, Func onStop) { using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) { @@ -354,13 +352,7 @@ namespace MQTTnet.Server var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); if (isConnectionPresent) { - await existingConnection.StopAsync(true); - - // TODO: This fixes a race condition with unit test Same_Client_Id_Connect_Disconnect_Event_Order. - // It is not clear where the issue is coming from. The connected event is fired BEFORE the disconnected - // event. This is wrong. It seems that the finally block in HandleClientAsync must be finished before we - // can continue here. Maybe there is a better way to do this. - SpinWait.SpinUntil(() => existingConnection.IsFinalized, TimeSpan.FromSeconds(10)); + await existingConnection.StopAsync(true).ConfigureAwait(false); } if (isSessionPresent) @@ -383,7 +375,7 @@ namespace MQTTnet.Server _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } - var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, _logger); + var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, onStart, onStop, _logger); _connections[connection.ClientId] = connection; _sessions[session.ClientId] = session;