From c39da42ef25d694183be7f55ebd007bc5823d9b5 Mon Sep 17 00:00:00 2001 From: SilverFox Date: Thu, 1 Apr 2021 22:00:01 +0800 Subject: [PATCH 1/5] Disallow to call MqttClient.ConnectAsync while Disconnect is pending. Fix #996 Fix #1010 --- Source/MQTTnet/Client/MqttClient.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 638cf6b..3268c69 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -71,6 +71,9 @@ namespace MQTTnet.Client ThrowIfDisposed(); + if (Volatile.Read(ref _isDisconnectPending) != 0) + throw new InvalidOperationException("Not allowed to connect while disconnect is pending."); + MqttClientAuthenticateResult authenticateResult = null; try @@ -83,7 +86,6 @@ namespace MQTTnet.Client _backgroundCancellationTokenSource = new CancellationTokenSource(); var backgroundCancellationToken = _backgroundCancellationTokenSource.Token; - _isDisconnectPending = 0; var adapter = _adapterFactory.CreateClientAdapter(options); _adapter = adapter; @@ -357,6 +359,7 @@ namespace MQTTnet.Client { Cleanup(); _cleanDisconnectInitiated = false; + Volatile.Write(ref _isDisconnectPending, 0); _logger.Info("Disconnected."); From 9b10eadb241ca238bc947b595e9a4eb1a048f602 Mon Sep 17 00:00:00 2001 From: SilverFox Date: Fri, 2 Apr 2021 09:36:05 +0800 Subject: [PATCH 2/5] Fix broken tests --- Source/MQTTnet/Client/MqttClient.cs | 99 +++++++++++++++-------------- 1 file changed, 53 insertions(+), 46 deletions(-) diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 3268c69..5e30254 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -38,8 +38,11 @@ namespace MQTTnet.Client IMqttChannelAdapter _adapter; bool _cleanDisconnectInitiated; - long _isDisconnectPending; - bool _isConnected; + volatile int _connectState; + const int CONNECT_STATED_DISCONNECTED = 0; + const int CONNECT_STATED_DISCONNECTING = 1; + const int CONNECT_STATED_CONNECTED = 2; + const int CONNECT_STATED_CONNECTING = 3; MqttClientDisconnectReason _disconnectReason; DateTime _lastPacketSentTimestamp; @@ -58,7 +61,7 @@ namespace MQTTnet.Client public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } - public bool IsConnected => _isConnected && Interlocked.Read(ref _isDisconnectPending) == 0; + public bool IsConnected => _connectState == CONNECT_STATED_CONNECTED; public IMqttClientOptions Options { get; private set; } @@ -71,8 +74,8 @@ namespace MQTTnet.Client ThrowIfDisposed(); - if (Volatile.Read(ref _isDisconnectPending) != 0) - throw new InvalidOperationException("Not allowed to connect while disconnect is pending."); + if (Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_CONNECTING, CONNECT_STATED_DISCONNECTED) != CONNECT_STATED_DISCONNECTED) + throw new InvalidOperationException("Not allowed to connect while connect/disconnect is pending."); MqttClientAuthenticateResult authenticateResult = null; @@ -110,7 +113,7 @@ namespace MQTTnet.Client _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken); } - _isConnected = true; + Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_CONNECTED, CONNECT_STATED_CONNECTING); _logger.Info("Connected."); @@ -128,10 +131,7 @@ namespace MQTTnet.Client _logger.Error(exception, "Error while connecting with server."); - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false); - } + await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false); throw; } @@ -143,7 +143,8 @@ namespace MQTTnet.Client ThrowIfDisposed(); - if (DisconnectIsPending()) + var clientWasConnected = IsConnected; + if (DisconnectIsPendingOrFinished()) { return; } @@ -153,7 +154,7 @@ namespace MQTTnet.Client _disconnectReason = MqttClientDisconnectReason.NormalDisconnection; _cleanDisconnectInitiated = true; - if (_isConnected) + if (clientWasConnected) { var disconnectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateDisconnectPacket(options); await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false); @@ -161,7 +162,7 @@ namespace MQTTnet.Client } finally { - await DisconnectInternalAsync(null, null, null).ConfigureAwait(false); + await DisconnectCoreAsync(null, null, null, clientWasConnected).ConfigureAwait(false); } } @@ -308,7 +309,7 @@ namespace MQTTnet.Client void ThrowIfNotConnected() { - if (!IsConnected || Interlocked.Read(ref _isDisconnectPending) == 1) + if (!IsConnected) { throw new MqttCommunicationException("The client is not connected."); } @@ -319,12 +320,19 @@ namespace MQTTnet.Client if (IsConnected) throw new MqttProtocolViolationException(message); } - async Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult) + Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult) { - var clientWasConnected = _isConnected; + var clientWasConnected = IsConnected; + if (!DisconnectIsPendingOrFinished()) + { + return DisconnectCoreAsync(sender, exception, authenticateResult, clientWasConnected); + } + return PlatformAbstractionLayer.CompletedTask; + } + async Task DisconnectCoreAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult, bool clientWasConnected) + { TryInitiateDisconnect(); - _isConnected = false; try { @@ -348,8 +356,6 @@ namespace MQTTnet.Client var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender); await Task.WhenAll(receiverTask, publishPacketReceiverTask, keepAliveTask).ConfigureAwait(false); - - _publishPacketReceiverQueue?.Dispose(); } catch (Exception e) { @@ -359,7 +365,7 @@ namespace MQTTnet.Client { Cleanup(); _cleanDisconnectInitiated = false; - Volatile.Write(ref _isDisconnectPending, 0); + Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_DISCONNECTED, CONNECT_STATED_DISCONNECTING); _logger.Info("Disconnected."); @@ -481,10 +487,7 @@ namespace MQTTnet.Client _logger.Error(exception, "Error exception while sending/receiving keep alive packets."); } - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false); } finally { @@ -509,10 +512,7 @@ namespace MQTTnet.Client if (packet == null) { - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false); return; } @@ -541,10 +541,7 @@ namespace MQTTnet.Client _packetDispatcher.FailAll(exception); - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); } finally { @@ -613,10 +610,7 @@ namespace MQTTnet.Client _packetDispatcher.FailAll(exception); - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); } } @@ -718,12 +712,7 @@ namespace MQTTnet.Client // Also dispatch disconnect to waiting threads to generate a proper exception. _packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket)); - if (!DisconnectIsPending()) - { - return DisconnectInternalAsync(_packetReceiverTask, null, null); - } - - return PlatformAbstractionLayer.CompletedTask; + return DisconnectInternalAsync(_packetReceiverTask, null, null); } Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket) @@ -809,11 +798,29 @@ namespace MQTTnet.Client } } - bool DisconnectIsPending() + bool DisconnectIsPendingOrFinished() { - // This will read the _isDisconnectPending and set it to "1" afterwards regardless of the value. - // So the first caller will get a "false" and all subsequent ones will get "true". - return Interlocked.CompareExchange(ref _isDisconnectPending, 1, 0) != 0; + var connectState = _connectState; + do + { + switch (connectState) + { + case CONNECT_STATED_DISCONNECTING: + case CONNECT_STATED_DISCONNECTED: + return true; + case CONNECT_STATED_CONNECTING: + case CONNECT_STATED_CONNECTED: + // This will compare the _connectState to old value and set it to "CONNECT_STATED_DISCONNECTING" afterwards. + // So the first caller will get a "false" and all subsequent ones will get "true". + var newState = Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_DISCONNECTING, connectState); + if (newState != connectState) + { + return false; + } + connectState = newState; + break; + } + } while (true); } } } \ No newline at end of file From f99daee4d1acc4853af19ef2b6ca2c556fbe28c7 Mon Sep 17 00:00:00 2001 From: SilverFox Date: Fri, 2 Apr 2021 17:53:53 +0800 Subject: [PATCH 3/5] Fix Session_Tests.Manage_Session_MaxParallel --- .../Server/MqttClientSessionsManager.cs | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 2c3202b..bec93db 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,4 +1,4 @@ -using MQTTnet.Adapter; +using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Formatter; @@ -21,7 +21,7 @@ namespace MQTTnet.Server { readonly BlockingCollection _messageQueue = new BlockingCollection(); - readonly object _createConnectionSyncRoot = new object(); + readonly AsyncLock _createConnectionSyncRoot = new AsyncLock(); readonly Dictionary _connections = new Dictionary(); readonly Dictionary _sessions = new Dictionary(); @@ -98,7 +98,7 @@ namespace MQTTnet.Server return; } - var connection = CreateClientConnection(connectPacket, connectionValidatorContext, channelAdapter); + var connection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.SafeNotifyClientConnectedAsync(connectPacket.ClientId).ConfigureAwait(false); await connection.RunAsync().ConfigureAwait(false); } @@ -387,9 +387,12 @@ namespace MQTTnet.Server return context; } - MqttClientConnection CreateClientConnection(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) + async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) { - lock (_createConnectionSyncRoot) + MqttClientConnection existingConnection; + MqttClientConnection connection; + + using (await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { MqttClientSession session; lock (_sessions) @@ -415,8 +418,6 @@ namespace MQTTnet.Server _sessions[connectPacket.ClientId] = session; } - MqttClientConnection existingConnection; - MqttClientConnection connection; lock (_connections) { _connections.TryGetValue(connectPacket.ClientId, out existingConnection); @@ -425,10 +426,13 @@ namespace MQTTnet.Server _connections[connectPacket.ClientId] = connection; } - existingConnection?.StopAsync(MqttClientDisconnectReason.SessionTakenOver).GetAwaiter().GetResult(); - - return connection; + if (existingConnection != null) + { + await existingConnection.StopAsync(MqttClientDisconnectReason.SessionTakenOver).ConfigureAwait(false); + } } + + return connection; } async Task InterceptApplicationMessageAsync(IMqttServerApplicationMessageInterceptor interceptor, MqttClientConnection clientConnection, MqttApplicationMessage applicationMessage) From d913d4402ef14cb1a575223e7b7acd480d8dac1e Mon Sep 17 00:00:00 2001 From: SilverFox Date: Fri, 2 Apr 2021 19:47:06 +0800 Subject: [PATCH 4/5] Convert MqttClient._connectState to enum --- Source/MQTTnet/Client/MqttClient.cs | 40 ++++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 5e30254..ddc97e2 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -39,10 +39,15 @@ namespace MQTTnet.Client IMqttChannelAdapter _adapter; bool _cleanDisconnectInitiated; volatile int _connectState; - const int CONNECT_STATED_DISCONNECTED = 0; - const int CONNECT_STATED_DISCONNECTING = 1; - const int CONNECT_STATED_CONNECTED = 2; - const int CONNECT_STATED_CONNECTING = 3; + + enum ConnectState + { + Disconnected = 0, + Disconnecting, + Connected, + Connecting + } + MqttClientDisconnectReason _disconnectReason; DateTime _lastPacketSentTimestamp; @@ -61,7 +66,7 @@ namespace MQTTnet.Client public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } - public bool IsConnected => _connectState == CONNECT_STATED_CONNECTED; + public bool IsConnected => (ConnectState)_connectState == ConnectState.Connected; public IMqttClientOptions Options { get; private set; } @@ -74,7 +79,7 @@ namespace MQTTnet.Client ThrowIfDisposed(); - if (Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_CONNECTING, CONNECT_STATED_DISCONNECTED) != CONNECT_STATED_DISCONNECTED) + if (CompareExchangeConnectState(ConnectState.Connecting, ConnectState.Disconnected) != ConnectState.Disconnected) throw new InvalidOperationException("Not allowed to connect while connect/disconnect is pending."); MqttClientAuthenticateResult authenticateResult = null; @@ -113,7 +118,7 @@ namespace MQTTnet.Client _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken); } - Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_CONNECTED, CONNECT_STATED_CONNECTING); + CompareExchangeConnectState(ConnectState.Connected, ConnectState.Connecting); _logger.Info("Connected."); @@ -365,7 +370,7 @@ namespace MQTTnet.Client { Cleanup(); _cleanDisconnectInitiated = false; - Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_DISCONNECTED, CONNECT_STATED_DISCONNECTING); + CompareExchangeConnectState(ConnectState.Disconnected, ConnectState.Disconnecting); _logger.Info("Disconnected."); @@ -800,20 +805,20 @@ namespace MQTTnet.Client bool DisconnectIsPendingOrFinished() { - var connectState = _connectState; + var connectState = (ConnectState)_connectState; do { switch (connectState) { - case CONNECT_STATED_DISCONNECTING: - case CONNECT_STATED_DISCONNECTED: + case ConnectState.Disconnected: + case ConnectState.Disconnecting: return true; - case CONNECT_STATED_CONNECTING: - case CONNECT_STATED_CONNECTED: + case ConnectState.Connected: + case ConnectState.Connecting: // This will compare the _connectState to old value and set it to "CONNECT_STATED_DISCONNECTING" afterwards. // So the first caller will get a "false" and all subsequent ones will get "true". - var newState = Interlocked.CompareExchange(ref _connectState, CONNECT_STATED_DISCONNECTING, connectState); - if (newState != connectState) + var newState = CompareExchangeConnectState(ConnectState.Disconnecting, connectState); + if (newState == connectState) { return false; } @@ -822,5 +827,10 @@ namespace MQTTnet.Client } } while (true); } + + ConnectState CompareExchangeConnectState(ConnectState value, ConnectState comparand) + { + return (ConnectState)Interlocked.CompareExchange(ref _connectState, (int)value, (int)comparand); + } } } \ No newline at end of file From 899ed7b1da1a644d1992a69584473f480a0caecb Mon Sep 17 00:00:00 2001 From: SilverFox Date: Fri, 2 Apr 2021 22:38:20 +0800 Subject: [PATCH 5/5] Rename MqttClient.ConnectState to MqttClientConnectionStatus --- Source/MQTTnet/Client/MqttClient.cs | 42 ++++++++----------- .../Client/MqttClientConnectionStatus.cs | 11 +++++ 2 files changed, 28 insertions(+), 25 deletions(-) create mode 100644 Source/MQTTnet/Client/MqttClientConnectionStatus.cs diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index ddc97e2..fec2497 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -38,15 +38,7 @@ namespace MQTTnet.Client IMqttChannelAdapter _adapter; bool _cleanDisconnectInitiated; - volatile int _connectState; - - enum ConnectState - { - Disconnected = 0, - Disconnecting, - Connected, - Connecting - } + volatile int _connectionStatus; MqttClientDisconnectReason _disconnectReason; @@ -66,7 +58,7 @@ namespace MQTTnet.Client public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } - public bool IsConnected => (ConnectState)_connectState == ConnectState.Connected; + public bool IsConnected => (MqttClientConnectionStatus)_connectionStatus == MqttClientConnectionStatus.Connected; public IMqttClientOptions Options { get; private set; } @@ -79,7 +71,7 @@ namespace MQTTnet.Client ThrowIfDisposed(); - if (CompareExchangeConnectState(ConnectState.Connecting, ConnectState.Disconnected) != ConnectState.Disconnected) + if (CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connecting, MqttClientConnectionStatus.Disconnected) != MqttClientConnectionStatus.Disconnected) throw new InvalidOperationException("Not allowed to connect while connect/disconnect is pending."); MqttClientAuthenticateResult authenticateResult = null; @@ -118,7 +110,7 @@ namespace MQTTnet.Client _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken); } - CompareExchangeConnectState(ConnectState.Connected, ConnectState.Connecting); + CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connected, MqttClientConnectionStatus.Connecting); _logger.Info("Connected."); @@ -370,7 +362,7 @@ namespace MQTTnet.Client { Cleanup(); _cleanDisconnectInitiated = false; - CompareExchangeConnectState(ConnectState.Disconnected, ConnectState.Disconnecting); + CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnected, MqttClientConnectionStatus.Disconnecting); _logger.Info("Disconnected."); @@ -805,32 +797,32 @@ namespace MQTTnet.Client bool DisconnectIsPendingOrFinished() { - var connectState = (ConnectState)_connectState; + var connectionStatus = (MqttClientConnectionStatus)_connectionStatus; do { - switch (connectState) + switch (connectionStatus) { - case ConnectState.Disconnected: - case ConnectState.Disconnecting: + case MqttClientConnectionStatus.Disconnected: + case MqttClientConnectionStatus.Disconnecting: return true; - case ConnectState.Connected: - case ConnectState.Connecting: - // This will compare the _connectState to old value and set it to "CONNECT_STATED_DISCONNECTING" afterwards. + case MqttClientConnectionStatus.Connected: + case MqttClientConnectionStatus.Connecting: + // This will compare the _connectionStatus to old value and set it to "MqttClientConnectionStatus.Disconnecting" afterwards. // So the first caller will get a "false" and all subsequent ones will get "true". - var newState = CompareExchangeConnectState(ConnectState.Disconnecting, connectState); - if (newState == connectState) + var curStatus = CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnecting, connectionStatus); + if (curStatus == connectionStatus) { return false; } - connectState = newState; + connectionStatus = curStatus; break; } } while (true); } - ConnectState CompareExchangeConnectState(ConnectState value, ConnectState comparand) + MqttClientConnectionStatus CompareExchangeConnectionStatus(MqttClientConnectionStatus value, MqttClientConnectionStatus comparand) { - return (ConnectState)Interlocked.CompareExchange(ref _connectState, (int)value, (int)comparand); + return (MqttClientConnectionStatus)Interlocked.CompareExchange(ref _connectionStatus, (int)value, (int)comparand); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttClientConnectionStatus.cs b/Source/MQTTnet/Client/MqttClientConnectionStatus.cs new file mode 100644 index 0000000..2410f7d --- /dev/null +++ b/Source/MQTTnet/Client/MqttClientConnectionStatus.cs @@ -0,0 +1,11 @@ +namespace MQTTnet.Client +{ + public enum MqttClientConnectionStatus + + { + Disconnected = 0, + Disconnecting, + Connected, + Connecting + } +} \ No newline at end of file