Browse Source

Merge pull request #1134 from yyjdelete/issue996

Disallow to call MqttClient.ConnectAsync while Disconnect is pending.
release/3.x.x
Christian 3 years ago
committed by GitHub
parent
commit
962622a9f7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 53 deletions
  1. +56
    -44
      Source/MQTTnet/Client/MqttClient.cs
  2. +11
    -0
      Source/MQTTnet/Client/MqttClientConnectionStatus.cs
  3. +13
    -9
      Source/MQTTnet/Server/MqttClientSessionsManager.cs

+ 56
- 44
Source/MQTTnet/Client/MqttClient.cs View File

@@ -38,8 +38,8 @@ namespace MQTTnet.Client


IMqttChannelAdapter _adapter; IMqttChannelAdapter _adapter;
bool _cleanDisconnectInitiated; bool _cleanDisconnectInitiated;
long _isDisconnectPending;
bool _isConnected;
volatile int _connectionStatus;
MqttClientDisconnectReason _disconnectReason; MqttClientDisconnectReason _disconnectReason;


DateTime _lastPacketSentTimestamp; DateTime _lastPacketSentTimestamp;
@@ -58,7 +58,7 @@ namespace MQTTnet.Client


public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; }


public bool IsConnected => _isConnected && Interlocked.Read(ref _isDisconnectPending) == 0;
public bool IsConnected => (MqttClientConnectionStatus)_connectionStatus == MqttClientConnectionStatus.Connected;


public IMqttClientOptions Options { get; private set; } public IMqttClientOptions Options { get; private set; }


@@ -71,6 +71,9 @@ namespace MQTTnet.Client


ThrowIfDisposed(); ThrowIfDisposed();


if (CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connecting, MqttClientConnectionStatus.Disconnected) != MqttClientConnectionStatus.Disconnected)
throw new InvalidOperationException("Not allowed to connect while connect/disconnect is pending.");

MqttClientAuthenticateResult authenticateResult = null; MqttClientAuthenticateResult authenticateResult = null;


try try
@@ -83,7 +86,6 @@ namespace MQTTnet.Client
_backgroundCancellationTokenSource = new CancellationTokenSource(); _backgroundCancellationTokenSource = new CancellationTokenSource();
var backgroundCancellationToken = _backgroundCancellationTokenSource.Token; var backgroundCancellationToken = _backgroundCancellationTokenSource.Token;


_isDisconnectPending = 0;
var adapter = _adapterFactory.CreateClientAdapter(options); var adapter = _adapterFactory.CreateClientAdapter(options);
_adapter = adapter; _adapter = adapter;


@@ -108,7 +110,7 @@ namespace MQTTnet.Client
_keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken); _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken);
} }


_isConnected = true;
CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connected, MqttClientConnectionStatus.Connecting);


_logger.Info("Connected."); _logger.Info("Connected.");


@@ -126,10 +128,7 @@ namespace MQTTnet.Client


_logger.Error(exception, "Error while connecting with server."); _logger.Error(exception, "Error while connecting with server.");


if (!DisconnectIsPending())
{
await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false);
}
await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false);


throw; throw;
} }
@@ -141,7 +140,8 @@ namespace MQTTnet.Client


ThrowIfDisposed(); ThrowIfDisposed();


if (DisconnectIsPending())
var clientWasConnected = IsConnected;
if (DisconnectIsPendingOrFinished())
{ {
return; return;
} }
@@ -151,7 +151,7 @@ namespace MQTTnet.Client
_disconnectReason = MqttClientDisconnectReason.NormalDisconnection; _disconnectReason = MqttClientDisconnectReason.NormalDisconnection;
_cleanDisconnectInitiated = true; _cleanDisconnectInitiated = true;


if (_isConnected)
if (clientWasConnected)
{ {
var disconnectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateDisconnectPacket(options); var disconnectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateDisconnectPacket(options);
await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false); await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false);
@@ -159,7 +159,7 @@ namespace MQTTnet.Client
} }
finally finally
{ {
await DisconnectInternalAsync(null, null, null).ConfigureAwait(false);
await DisconnectCoreAsync(null, null, null, clientWasConnected).ConfigureAwait(false);
} }
} }


@@ -306,7 +306,7 @@ namespace MQTTnet.Client


void ThrowIfNotConnected() void ThrowIfNotConnected()
{ {
if (!IsConnected || Interlocked.Read(ref _isDisconnectPending) == 1)
if (!IsConnected)
{ {
throw new MqttCommunicationException("The client is not connected."); throw new MqttCommunicationException("The client is not connected.");
} }
@@ -317,12 +317,19 @@ namespace MQTTnet.Client
if (IsConnected) throw new MqttProtocolViolationException(message); if (IsConnected) throw new MqttProtocolViolationException(message);
} }


async Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult)
Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult)
{ {
var clientWasConnected = _isConnected;
var clientWasConnected = IsConnected;
if (!DisconnectIsPendingOrFinished())
{
return DisconnectCoreAsync(sender, exception, authenticateResult, clientWasConnected);
}
return PlatformAbstractionLayer.CompletedTask;
}


async Task DisconnectCoreAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult, bool clientWasConnected)
{
TryInitiateDisconnect(); TryInitiateDisconnect();
_isConnected = false;


try try
{ {
@@ -346,8 +353,6 @@ namespace MQTTnet.Client
var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender); var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender);


await Task.WhenAll(receiverTask, publishPacketReceiverTask, keepAliveTask).ConfigureAwait(false); await Task.WhenAll(receiverTask, publishPacketReceiverTask, keepAliveTask).ConfigureAwait(false);

_publishPacketReceiverQueue?.Dispose();
} }
catch (Exception e) catch (Exception e)
{ {
@@ -357,6 +362,7 @@ namespace MQTTnet.Client
{ {
Cleanup(); Cleanup();
_cleanDisconnectInitiated = false; _cleanDisconnectInitiated = false;
CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnected, MqttClientConnectionStatus.Disconnecting);


_logger.Info("Disconnected."); _logger.Info("Disconnected.");


@@ -478,10 +484,7 @@ namespace MQTTnet.Client
_logger.Error(exception, "Error exception while sending/receiving keep alive packets."); _logger.Error(exception, "Error exception while sending/receiving keep alive packets.");
} }


if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false);
} }
finally finally
{ {
@@ -506,10 +509,7 @@ namespace MQTTnet.Client


if (packet == null) if (packet == null)
{ {
if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false);


return; return;
} }
@@ -538,10 +538,7 @@ namespace MQTTnet.Client


_packetDispatcher.FailAll(exception); _packetDispatcher.FailAll(exception);


if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
} }
finally finally
{ {
@@ -610,10 +607,7 @@ namespace MQTTnet.Client


_packetDispatcher.FailAll(exception); _packetDispatcher.FailAll(exception);


if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
} }
} }


@@ -715,12 +709,7 @@ namespace MQTTnet.Client
// Also dispatch disconnect to waiting threads to generate a proper exception. // Also dispatch disconnect to waiting threads to generate a proper exception.
_packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket)); _packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));


if (!DisconnectIsPending())
{
return DisconnectInternalAsync(_packetReceiverTask, null, null);
}

return PlatformAbstractionLayer.CompletedTask;
return DisconnectInternalAsync(_packetReceiverTask, null, null);
} }


Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket) Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket)
@@ -806,11 +795,34 @@ namespace MQTTnet.Client
} }
} }


bool DisconnectIsPending()
bool DisconnectIsPendingOrFinished()
{
var connectionStatus = (MqttClientConnectionStatus)_connectionStatus;
do
{
switch (connectionStatus)
{
case MqttClientConnectionStatus.Disconnected:
case MqttClientConnectionStatus.Disconnecting:
return true;
case MqttClientConnectionStatus.Connected:
case MqttClientConnectionStatus.Connecting:
// This will compare the _connectionStatus to old value and set it to "MqttClientConnectionStatus.Disconnecting" afterwards.
// So the first caller will get a "false" and all subsequent ones will get "true".
var curStatus = CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnecting, connectionStatus);
if (curStatus == connectionStatus)
{
return false;
}
connectionStatus = curStatus;
break;
}
} while (true);
}

MqttClientConnectionStatus CompareExchangeConnectionStatus(MqttClientConnectionStatus value, MqttClientConnectionStatus comparand)
{ {
// This will read the _isDisconnectPending and set it to "1" afterwards regardless of the value.
// So the first caller will get a "false" and all subsequent ones will get "true".
return Interlocked.CompareExchange(ref _isDisconnectPending, 1, 0) != 0;
return (MqttClientConnectionStatus)Interlocked.CompareExchange(ref _connectionStatus, (int)value, (int)comparand);
} }
} }
} }

+ 11
- 0
Source/MQTTnet/Client/MqttClientConnectionStatus.cs View File

@@ -0,0 +1,11 @@
namespace MQTTnet.Client
{
public enum MqttClientConnectionStatus

{
Disconnected = 0,
Disconnecting,
Connected,
Connecting
}
}

+ 13
- 9
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -21,7 +21,7 @@ namespace MQTTnet.Server
{ {
readonly BlockingCollection<MqttPendingApplicationMessage> _messageQueue = new BlockingCollection<MqttPendingApplicationMessage>(); readonly BlockingCollection<MqttPendingApplicationMessage> _messageQueue = new BlockingCollection<MqttPendingApplicationMessage>();


readonly object _createConnectionSyncRoot = new object();
readonly AsyncLock _createConnectionSyncRoot = new AsyncLock();
readonly Dictionary<string, MqttClientConnection> _connections = new Dictionary<string, MqttClientConnection>(); readonly Dictionary<string, MqttClientConnection> _connections = new Dictionary<string, MqttClientConnection>();
readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>(); readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>();


@@ -98,7 +98,7 @@ namespace MQTTnet.Server
return; return;
} }


var connection = CreateClientConnection(connectPacket, connectionValidatorContext, channelAdapter);
var connection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientConnectedAsync(connectPacket.ClientId).ConfigureAwait(false); await _eventDispatcher.SafeNotifyClientConnectedAsync(connectPacket.ClientId).ConfigureAwait(false);
await connection.RunAsync().ConfigureAwait(false); await connection.RunAsync().ConfigureAwait(false);
} }
@@ -389,9 +389,12 @@ namespace MQTTnet.Server
return context; return context;
} }


MqttClientConnection CreateClientConnection(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
async Task<MqttClientConnection> CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
{ {
lock (_createConnectionSyncRoot)
MqttClientConnection existingConnection;
MqttClientConnection connection;

using (await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
MqttClientSession session; MqttClientSession session;
lock (_sessions) lock (_sessions)
@@ -417,8 +420,6 @@ namespace MQTTnet.Server
_sessions[connectPacket.ClientId] = session; _sessions[connectPacket.ClientId] = session;
} }


MqttClientConnection existingConnection;
MqttClientConnection connection;
lock (_connections) lock (_connections)
{ {
_connections.TryGetValue(connectPacket.ClientId, out existingConnection); _connections.TryGetValue(connectPacket.ClientId, out existingConnection);
@@ -427,10 +428,13 @@ namespace MQTTnet.Server
_connections[connectPacket.ClientId] = connection; _connections[connectPacket.ClientId] = connection;
} }


existingConnection?.StopAsync(MqttClientDisconnectReason.SessionTakenOver).GetAwaiter().GetResult();

return connection;
if (existingConnection != null)
{
await existingConnection.StopAsync(MqttClientDisconnectReason.SessionTakenOver).ConfigureAwait(false);
}
} }

return connection;
} }


async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(IMqttServerApplicationMessageInterceptor interceptor, MqttClientConnection clientConnection, MqttApplicationMessage applicationMessage) async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(IMqttServerApplicationMessageInterceptor interceptor, MqttClientConnection clientConnection, MqttApplicationMessage applicationMessage)


Loading…
Cancel
Save