diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 638cf6b..fec2497 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -38,8 +38,8 @@ namespace MQTTnet.Client IMqttChannelAdapter _adapter; bool _cleanDisconnectInitiated; - long _isDisconnectPending; - bool _isConnected; + volatile int _connectionStatus; + MqttClientDisconnectReason _disconnectReason; DateTime _lastPacketSentTimestamp; @@ -58,7 +58,7 @@ namespace MQTTnet.Client public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } - public bool IsConnected => _isConnected && Interlocked.Read(ref _isDisconnectPending) == 0; + public bool IsConnected => (MqttClientConnectionStatus)_connectionStatus == MqttClientConnectionStatus.Connected; public IMqttClientOptions Options { get; private set; } @@ -71,6 +71,9 @@ namespace MQTTnet.Client ThrowIfDisposed(); + if (CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connecting, MqttClientConnectionStatus.Disconnected) != MqttClientConnectionStatus.Disconnected) + throw new InvalidOperationException("Not allowed to connect while connect/disconnect is pending."); + MqttClientAuthenticateResult authenticateResult = null; try @@ -83,7 +86,6 @@ namespace MQTTnet.Client _backgroundCancellationTokenSource = new CancellationTokenSource(); var backgroundCancellationToken = _backgroundCancellationTokenSource.Token; - _isDisconnectPending = 0; var adapter = _adapterFactory.CreateClientAdapter(options); _adapter = adapter; @@ -108,7 +110,7 @@ namespace MQTTnet.Client _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken); } - _isConnected = true; + CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connected, MqttClientConnectionStatus.Connecting); _logger.Info("Connected."); @@ -126,10 +128,7 @@ namespace MQTTnet.Client _logger.Error(exception, "Error while connecting with server."); - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false); - } + await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false); throw; } @@ -141,7 +140,8 @@ namespace MQTTnet.Client ThrowIfDisposed(); - if (DisconnectIsPending()) + var clientWasConnected = IsConnected; + if (DisconnectIsPendingOrFinished()) { return; } @@ -151,7 +151,7 @@ namespace MQTTnet.Client _disconnectReason = MqttClientDisconnectReason.NormalDisconnection; _cleanDisconnectInitiated = true; - if (_isConnected) + if (clientWasConnected) { var disconnectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateDisconnectPacket(options); await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false); @@ -159,7 +159,7 @@ namespace MQTTnet.Client } finally { - await DisconnectInternalAsync(null, null, null).ConfigureAwait(false); + await DisconnectCoreAsync(null, null, null, clientWasConnected).ConfigureAwait(false); } } @@ -306,7 +306,7 @@ namespace MQTTnet.Client void ThrowIfNotConnected() { - if (!IsConnected || Interlocked.Read(ref _isDisconnectPending) == 1) + if (!IsConnected) { throw new MqttCommunicationException("The client is not connected."); } @@ -317,12 +317,19 @@ namespace MQTTnet.Client if (IsConnected) throw new MqttProtocolViolationException(message); } - async Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult) + Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult) { - var clientWasConnected = _isConnected; + var clientWasConnected = IsConnected; + if (!DisconnectIsPendingOrFinished()) + { + return DisconnectCoreAsync(sender, exception, authenticateResult, clientWasConnected); + } + return PlatformAbstractionLayer.CompletedTask; + } + async Task DisconnectCoreAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult, bool clientWasConnected) + { TryInitiateDisconnect(); - _isConnected = false; try { @@ -346,8 +353,6 @@ namespace MQTTnet.Client var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender); await Task.WhenAll(receiverTask, publishPacketReceiverTask, keepAliveTask).ConfigureAwait(false); - - _publishPacketReceiverQueue?.Dispose(); } catch (Exception e) { @@ -357,6 +362,7 @@ namespace MQTTnet.Client { Cleanup(); _cleanDisconnectInitiated = false; + CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnected, MqttClientConnectionStatus.Disconnecting); _logger.Info("Disconnected."); @@ -478,10 +484,7 @@ namespace MQTTnet.Client _logger.Error(exception, "Error exception while sending/receiving keep alive packets."); } - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false); } finally { @@ -506,10 +509,7 @@ namespace MQTTnet.Client if (packet == null) { - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false); return; } @@ -538,10 +538,7 @@ namespace MQTTnet.Client _packetDispatcher.FailAll(exception); - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); } finally { @@ -610,10 +607,7 @@ namespace MQTTnet.Client _packetDispatcher.FailAll(exception); - if (!DisconnectIsPending()) - { - await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); - } + await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); } } @@ -715,12 +709,7 @@ namespace MQTTnet.Client // Also dispatch disconnect to waiting threads to generate a proper exception. _packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket)); - if (!DisconnectIsPending()) - { - return DisconnectInternalAsync(_packetReceiverTask, null, null); - } - - return PlatformAbstractionLayer.CompletedTask; + return DisconnectInternalAsync(_packetReceiverTask, null, null); } Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket) @@ -806,11 +795,34 @@ namespace MQTTnet.Client } } - bool DisconnectIsPending() + bool DisconnectIsPendingOrFinished() + { + var connectionStatus = (MqttClientConnectionStatus)_connectionStatus; + do + { + switch (connectionStatus) + { + case MqttClientConnectionStatus.Disconnected: + case MqttClientConnectionStatus.Disconnecting: + return true; + case MqttClientConnectionStatus.Connected: + case MqttClientConnectionStatus.Connecting: + // This will compare the _connectionStatus to old value and set it to "MqttClientConnectionStatus.Disconnecting" afterwards. + // So the first caller will get a "false" and all subsequent ones will get "true". + var curStatus = CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnecting, connectionStatus); + if (curStatus == connectionStatus) + { + return false; + } + connectionStatus = curStatus; + break; + } + } while (true); + } + + MqttClientConnectionStatus CompareExchangeConnectionStatus(MqttClientConnectionStatus value, MqttClientConnectionStatus comparand) { - // This will read the _isDisconnectPending and set it to "1" afterwards regardless of the value. - // So the first caller will get a "false" and all subsequent ones will get "true". - return Interlocked.CompareExchange(ref _isDisconnectPending, 1, 0) != 0; + return (MqttClientConnectionStatus)Interlocked.CompareExchange(ref _connectionStatus, (int)value, (int)comparand); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttClientConnectionStatus.cs b/Source/MQTTnet/Client/MqttClientConnectionStatus.cs new file mode 100644 index 0000000..2410f7d --- /dev/null +++ b/Source/MQTTnet/Client/MqttClientConnectionStatus.cs @@ -0,0 +1,11 @@ +namespace MQTTnet.Client +{ + public enum MqttClientConnectionStatus + + { + Disconnected = 0, + Disconnecting, + Connected, + Connecting + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index f0c6955..3784d42 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -21,7 +21,7 @@ namespace MQTTnet.Server { readonly BlockingCollection _messageQueue = new BlockingCollection(); - readonly object _createConnectionSyncRoot = new object(); + readonly AsyncLock _createConnectionSyncRoot = new AsyncLock(); readonly Dictionary _connections = new Dictionary(); readonly Dictionary _sessions = new Dictionary(); @@ -98,7 +98,7 @@ namespace MQTTnet.Server return; } - var connection = CreateClientConnection(connectPacket, connectionValidatorContext, channelAdapter); + var connection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.SafeNotifyClientConnectedAsync(connectPacket.ClientId).ConfigureAwait(false); await connection.RunAsync().ConfigureAwait(false); } @@ -389,9 +389,12 @@ namespace MQTTnet.Server return context; } - MqttClientConnection CreateClientConnection(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) + async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) { - lock (_createConnectionSyncRoot) + MqttClientConnection existingConnection; + MqttClientConnection connection; + + using (await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { MqttClientSession session; lock (_sessions) @@ -417,8 +420,6 @@ namespace MQTTnet.Server _sessions[connectPacket.ClientId] = session; } - MqttClientConnection existingConnection; - MqttClientConnection connection; lock (_connections) { _connections.TryGetValue(connectPacket.ClientId, out existingConnection); @@ -427,10 +428,13 @@ namespace MQTTnet.Server _connections[connectPacket.ClientId] = connection; } - existingConnection?.StopAsync(MqttClientDisconnectReason.SessionTakenOver).GetAwaiter().GetResult(); - - return connection; + if (existingConnection != null) + { + await existingConnection.StopAsync(MqttClientDisconnectReason.SessionTakenOver).ConfigureAwait(false); + } } + + return connection; } async Task InterceptApplicationMessageAsync(IMqttServerApplicationMessageInterceptor interceptor, MqttClientConnection clientConnection, MqttApplicationMessage applicationMessage)