Procházet zdrojové kódy

Merge pull request #1134 from yyjdelete/issue996

Disallow to call MqttClient.ConnectAsync while Disconnect is pending.
release/3.x.x
Christian před 3 roky
committed by GitHub
rodič
revize
962622a9f7
V databázi nebyl nalezen žádný známý klíč pro tento podpis ID GPG klíče: 4AEE18F83AFDEB23
3 změnil soubory, kde provedl 80 přidání a 53 odebrání
  1. +56
    -44
      Source/MQTTnet/Client/MqttClient.cs
  2. +11
    -0
      Source/MQTTnet/Client/MqttClientConnectionStatus.cs
  3. +13
    -9
      Source/MQTTnet/Server/MqttClientSessionsManager.cs

+ 56
- 44
Source/MQTTnet/Client/MqttClient.cs Zobrazit soubor

@@ -38,8 +38,8 @@ namespace MQTTnet.Client

IMqttChannelAdapter _adapter;
bool _cleanDisconnectInitiated;
long _isDisconnectPending;
bool _isConnected;
volatile int _connectionStatus;
MqttClientDisconnectReason _disconnectReason;

DateTime _lastPacketSentTimestamp;
@@ -58,7 +58,7 @@ namespace MQTTnet.Client

public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; }

public bool IsConnected => _isConnected && Interlocked.Read(ref _isDisconnectPending) == 0;
public bool IsConnected => (MqttClientConnectionStatus)_connectionStatus == MqttClientConnectionStatus.Connected;

public IMqttClientOptions Options { get; private set; }

@@ -71,6 +71,9 @@ namespace MQTTnet.Client

ThrowIfDisposed();

if (CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connecting, MqttClientConnectionStatus.Disconnected) != MqttClientConnectionStatus.Disconnected)
throw new InvalidOperationException("Not allowed to connect while connect/disconnect is pending.");

MqttClientAuthenticateResult authenticateResult = null;

try
@@ -83,7 +86,6 @@ namespace MQTTnet.Client
_backgroundCancellationTokenSource = new CancellationTokenSource();
var backgroundCancellationToken = _backgroundCancellationTokenSource.Token;

_isDisconnectPending = 0;
var adapter = _adapterFactory.CreateClientAdapter(options);
_adapter = adapter;

@@ -108,7 +110,7 @@ namespace MQTTnet.Client
_keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(backgroundCancellationToken), backgroundCancellationToken);
}

_isConnected = true;
CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connected, MqttClientConnectionStatus.Connecting);

_logger.Info("Connected.");

@@ -126,10 +128,7 @@ namespace MQTTnet.Client

_logger.Error(exception, "Error while connecting with server.");

if (!DisconnectIsPending())
{
await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false);
}
await DisconnectInternalAsync(null, exception, authenticateResult).ConfigureAwait(false);

throw;
}
@@ -141,7 +140,8 @@ namespace MQTTnet.Client

ThrowIfDisposed();

if (DisconnectIsPending())
var clientWasConnected = IsConnected;
if (DisconnectIsPendingOrFinished())
{
return;
}
@@ -151,7 +151,7 @@ namespace MQTTnet.Client
_disconnectReason = MqttClientDisconnectReason.NormalDisconnection;
_cleanDisconnectInitiated = true;

if (_isConnected)
if (clientWasConnected)
{
var disconnectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateDisconnectPacket(options);
await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false);
@@ -159,7 +159,7 @@ namespace MQTTnet.Client
}
finally
{
await DisconnectInternalAsync(null, null, null).ConfigureAwait(false);
await DisconnectCoreAsync(null, null, null, clientWasConnected).ConfigureAwait(false);
}
}

@@ -306,7 +306,7 @@ namespace MQTTnet.Client

void ThrowIfNotConnected()
{
if (!IsConnected || Interlocked.Read(ref _isDisconnectPending) == 1)
if (!IsConnected)
{
throw new MqttCommunicationException("The client is not connected.");
}
@@ -317,12 +317,19 @@ namespace MQTTnet.Client
if (IsConnected) throw new MqttProtocolViolationException(message);
}

async Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult)
Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult)
{
var clientWasConnected = _isConnected;
var clientWasConnected = IsConnected;
if (!DisconnectIsPendingOrFinished())
{
return DisconnectCoreAsync(sender, exception, authenticateResult, clientWasConnected);
}
return PlatformAbstractionLayer.CompletedTask;
}

async Task DisconnectCoreAsync(Task sender, Exception exception, MqttClientAuthenticateResult authenticateResult, bool clientWasConnected)
{
TryInitiateDisconnect();
_isConnected = false;

try
{
@@ -346,8 +353,6 @@ namespace MQTTnet.Client
var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender);

await Task.WhenAll(receiverTask, publishPacketReceiverTask, keepAliveTask).ConfigureAwait(false);

_publishPacketReceiverQueue?.Dispose();
}
catch (Exception e)
{
@@ -357,6 +362,7 @@ namespace MQTTnet.Client
{
Cleanup();
_cleanDisconnectInitiated = false;
CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnected, MqttClientConnectionStatus.Disconnecting);

_logger.Info("Disconnected.");

@@ -478,10 +484,7 @@ namespace MQTTnet.Client
_logger.Error(exception, "Error exception while sending/receiving keep alive packets.");
}

if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false);
}
finally
{
@@ -506,10 +509,7 @@ namespace MQTTnet.Client

if (packet == null)
{
if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false);

return;
}
@@ -538,10 +538,7 @@ namespace MQTTnet.Client

_packetDispatcher.FailAll(exception);

if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
}
finally
{
@@ -610,10 +607,7 @@ namespace MQTTnet.Client

_packetDispatcher.FailAll(exception);

if (!DisconnectIsPending())
{
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
}
await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false);
}
}

@@ -715,12 +709,7 @@ namespace MQTTnet.Client
// Also dispatch disconnect to waiting threads to generate a proper exception.
_packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));

if (!DisconnectIsPending())
{
return DisconnectInternalAsync(_packetReceiverTask, null, null);
}

return PlatformAbstractionLayer.CompletedTask;
return DisconnectInternalAsync(_packetReceiverTask, null, null);
}

Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket)
@@ -806,11 +795,34 @@ namespace MQTTnet.Client
}
}

bool DisconnectIsPending()
bool DisconnectIsPendingOrFinished()
{
var connectionStatus = (MqttClientConnectionStatus)_connectionStatus;
do
{
switch (connectionStatus)
{
case MqttClientConnectionStatus.Disconnected:
case MqttClientConnectionStatus.Disconnecting:
return true;
case MqttClientConnectionStatus.Connected:
case MqttClientConnectionStatus.Connecting:
// This will compare the _connectionStatus to old value and set it to "MqttClientConnectionStatus.Disconnecting" afterwards.
// So the first caller will get a "false" and all subsequent ones will get "true".
var curStatus = CompareExchangeConnectionStatus(MqttClientConnectionStatus.Disconnecting, connectionStatus);
if (curStatus == connectionStatus)
{
return false;
}
connectionStatus = curStatus;
break;
}
} while (true);
}

MqttClientConnectionStatus CompareExchangeConnectionStatus(MqttClientConnectionStatus value, MqttClientConnectionStatus comparand)
{
// This will read the _isDisconnectPending and set it to "1" afterwards regardless of the value.
// So the first caller will get a "false" and all subsequent ones will get "true".
return Interlocked.CompareExchange(ref _isDisconnectPending, 1, 0) != 0;
return (MqttClientConnectionStatus)Interlocked.CompareExchange(ref _connectionStatus, (int)value, (int)comparand);
}
}
}

+ 11
- 0
Source/MQTTnet/Client/MqttClientConnectionStatus.cs Zobrazit soubor

@@ -0,0 +1,11 @@
namespace MQTTnet.Client
{
public enum MqttClientConnectionStatus

{
Disconnected = 0,
Disconnecting,
Connected,
Connecting
}
}

+ 13
- 9
Source/MQTTnet/Server/MqttClientSessionsManager.cs Zobrazit soubor

@@ -21,7 +21,7 @@ namespace MQTTnet.Server
{
readonly BlockingCollection<MqttPendingApplicationMessage> _messageQueue = new BlockingCollection<MqttPendingApplicationMessage>();

readonly object _createConnectionSyncRoot = new object();
readonly AsyncLock _createConnectionSyncRoot = new AsyncLock();
readonly Dictionary<string, MqttClientConnection> _connections = new Dictionary<string, MqttClientConnection>();
readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>();

@@ -98,7 +98,7 @@ namespace MQTTnet.Server
return;
}

var connection = CreateClientConnection(connectPacket, connectionValidatorContext, channelAdapter);
var connection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientConnectedAsync(connectPacket.ClientId).ConfigureAwait(false);
await connection.RunAsync().ConfigureAwait(false);
}
@@ -389,9 +389,12 @@ namespace MQTTnet.Server
return context;
}

MqttClientConnection CreateClientConnection(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
async Task<MqttClientConnection> CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter)
{
lock (_createConnectionSyncRoot)
MqttClientConnection existingConnection;
MqttClientConnection connection;

using (await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
MqttClientSession session;
lock (_sessions)
@@ -417,8 +420,6 @@ namespace MQTTnet.Server
_sessions[connectPacket.ClientId] = session;
}

MqttClientConnection existingConnection;
MqttClientConnection connection;
lock (_connections)
{
_connections.TryGetValue(connectPacket.ClientId, out existingConnection);
@@ -427,10 +428,13 @@ namespace MQTTnet.Server
_connections[connectPacket.ClientId] = connection;
}

existingConnection?.StopAsync(MqttClientDisconnectReason.SessionTakenOver).GetAwaiter().GetResult();

return connection;
if (existingConnection != null)
{
await existingConnection.StopAsync(MqttClientDisconnectReason.SessionTakenOver).ConfigureAwait(false);
}
}

return connection;
}

async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(IMqttServerApplicationMessageInterceptor interceptor, MqttClientConnection clientConnection, MqttApplicationMessage applicationMessage)


Načítá se…
Zrušit
Uložit