This failing test was caused by a raise condition in the MqttClientSessionsManager The test failed as the value of `flow` in the assertion `Assert.AreEqual("cdc", flow);` was "ccd"release/3.x.x
@@ -23,6 +23,8 @@ namespace MQTTnet.Server | |||||
readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); | readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); | ||||
readonly IMqttRetainedMessagesManager _retainedMessagesManager; | readonly IMqttRetainedMessagesManager _retainedMessagesManager; | ||||
readonly Func<Task> _onStart; | |||||
readonly Func<MqttClientDisconnectType, Task> _onStop; | |||||
readonly MqttClientKeepAliveMonitor _keepAliveMonitor; | readonly MqttClientKeepAliveMonitor _keepAliveMonitor; | ||||
readonly MqttClientSessionsManager _sessionsManager; | readonly MqttClientSessionsManager _sessionsManager; | ||||
@@ -34,7 +36,7 @@ namespace MQTTnet.Server | |||||
readonly string _endpoint; | readonly string _endpoint; | ||||
readonly DateTime _connectedTimestamp; | readonly DateTime _connectedTimestamp; | ||||
Task<MqttClientDisconnectType> _packageReceiverTask; | |||||
volatile Task _packageReceiverTask; | |||||
DateTime _lastPacketReceivedTimestamp; | DateTime _lastPacketReceivedTimestamp; | ||||
DateTime _lastNonKeepAlivePacketReceivedTimestamp; | DateTime _lastNonKeepAlivePacketReceivedTimestamp; | ||||
@@ -43,7 +45,7 @@ namespace MQTTnet.Server | |||||
long _receivedApplicationMessagesCount; | long _receivedApplicationMessagesCount; | ||||
long _sentApplicationMessagesCount; | long _sentApplicationMessagesCount; | ||||
bool _isTakeover; | |||||
volatile bool _isTakeover; | |||||
public MqttClientConnection( | public MqttClientConnection( | ||||
MqttConnectPacket connectPacket, | MqttConnectPacket connectPacket, | ||||
@@ -52,12 +54,16 @@ namespace MQTTnet.Server | |||||
IMqttServerOptions serverOptions, | IMqttServerOptions serverOptions, | ||||
MqttClientSessionsManager sessionsManager, | MqttClientSessionsManager sessionsManager, | ||||
IMqttRetainedMessagesManager retainedMessagesManager, | IMqttRetainedMessagesManager retainedMessagesManager, | ||||
Func<Task> onStart, | |||||
Func<MqttClientDisconnectType, Task> onStop, | |||||
IMqttNetLogger logger) | IMqttNetLogger logger) | ||||
{ | { | ||||
Session = session ?? throw new ArgumentNullException(nameof(session)); | Session = session ?? throw new ArgumentNullException(nameof(session)); | ||||
_serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); | _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); | ||||
_sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); | _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); | ||||
_retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); | _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); | ||||
_onStart = onStart ?? throw new ArgumentNullException(nameof(onStart)); | |||||
_onStop = onStop ?? throw new ArgumentNullException(nameof(onStop)); | |||||
_channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); | _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); | ||||
_dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; | _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; | ||||
@@ -80,15 +86,13 @@ namespace MQTTnet.Server | |||||
public MqttClientSession Session { get; } | public MqttClientSession Session { get; } | ||||
public bool IsFinalized { get; set; } | |||||
public Task StopAsync(bool isTakeover = false) | public Task StopAsync(bool isTakeover = false) | ||||
{ | { | ||||
_isTakeover = isTakeover; | _isTakeover = isTakeover; | ||||
var task = _packageReceiverTask; | |||||
StopInternal(); | StopInternal(); | ||||
var task = _packageReceiverTask; | |||||
if (task != null) | if (task != null) | ||||
{ | { | ||||
return task; | return task; | ||||
@@ -127,17 +131,18 @@ namespace MQTTnet.Server | |||||
_cancellationToken.Dispose(); | _cancellationToken.Dispose(); | ||||
} | } | ||||
public Task<MqttClientDisconnectType> RunAsync(MqttConnectionValidatorContext connectionValidatorContext) | |||||
public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) | |||||
{ | { | ||||
_packageReceiverTask = RunInternalAsync(connectionValidatorContext); | _packageReceiverTask = RunInternalAsync(connectionValidatorContext); | ||||
return _packageReceiverTask; | return _packageReceiverTask; | ||||
} | } | ||||
async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) | |||||
async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) | |||||
{ | { | ||||
var disconnectType = MqttClientDisconnectType.NotClean; | var disconnectType = MqttClientDisconnectType.NotClean; | ||||
try | try | ||||
{ | { | ||||
await _onStart(); | |||||
_logger.Info("Client '{0}': Session started.", ClientId); | _logger.Info("Client '{0}': Session started.", ClientId); | ||||
_channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; | _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; | ||||
@@ -241,6 +246,11 @@ namespace MQTTnet.Server | |||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
if (_isTakeover) | |||||
{ | |||||
disconnectType = MqttClientDisconnectType.Takeover; | |||||
} | |||||
if (Session.WillMessage != null) | if (Session.WillMessage != null) | ||||
{ | { | ||||
_sessionsManager.DispatchApplicationMessage(Session.WillMessage, this); | _sessionsManager.DispatchApplicationMessage(Session.WillMessage, this); | ||||
@@ -255,14 +265,16 @@ namespace MQTTnet.Server | |||||
_logger.Info("Client '{0}': Connection stopped.", ClientId); | _logger.Info("Client '{0}': Connection stopped.", ClientId); | ||||
_packageReceiverTask = null; | _packageReceiverTask = null; | ||||
} | |||||
if (_isTakeover) | |||||
{ | |||||
return MqttClientDisconnectType.Takeover; | |||||
try | |||||
{ | |||||
await _onStop(disconnectType); | |||||
} | |||||
catch (Exception e) | |||||
{ | |||||
_logger.Error(e, "client '{0}': Error while cleaning up", ClientId); | |||||
} | |||||
} | } | ||||
return disconnectType; | |||||
} | } | ||||
void StopInternal() | void StopInternal() | ||||
@@ -185,6 +185,12 @@ namespace MQTTnet.Server | |||||
} | } | ||||
var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); | var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); | ||||
if (!dequeueResult.IsSuccess) | |||||
{ | |||||
return; | |||||
} | |||||
var queuedApplicationMessage = dequeueResult.Item; | var queuedApplicationMessage = dequeueResult.Item; | ||||
var sender = queuedApplicationMessage.Sender; | var sender = queuedApplicationMessage.Sender; | ||||
@@ -235,12 +241,9 @@ namespace MQTTnet.Server | |||||
async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) | async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) | ||||
{ | { | ||||
var disconnectType = MqttClientDisconnectType.NotClean; | |||||
string clientId = null; | string clientId = null; | ||||
var clientWasAuthorized = false; | |||||
MqttConnectPacket connectPacket; | MqttConnectPacket connectPacket; | ||||
MqttClientConnection clientConnection = null; | |||||
try | try | ||||
{ | { | ||||
try | try | ||||
@@ -271,13 +274,17 @@ namespace MQTTnet.Server | |||||
return; | return; | ||||
} | } | ||||
clientWasAuthorized = true; | |||||
clientId = connectPacket.ClientId; | clientId = connectPacket.ClientId; | ||||
clientConnection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); | |||||
await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); | |||||
var connection = await CreateClientConnectionAsync( | |||||
connectPacket, | |||||
connectionValidatorContext, | |||||
channelAdapter, | |||||
async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), | |||||
async disconnectType => await CleanUpClient(clientId, channelAdapter, disconnectType) | |||||
).ConfigureAwait(false); | |||||
disconnectType = await clientConnection.RunAsync(connectionValidatorContext).ConfigureAwait(false); | |||||
await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); | |||||
} | } | ||||
catch (OperationCanceledException) | catch (OperationCanceledException) | ||||
{ | { | ||||
@@ -286,34 +293,25 @@ namespace MQTTnet.Server | |||||
{ | { | ||||
_logger.Error(exception, exception.Message); | _logger.Error(exception, exception.Message); | ||||
} | } | ||||
finally | |||||
{ | |||||
if (clientWasAuthorized && disconnectType != MqttClientDisconnectType.Takeover) | |||||
{ | |||||
// Only cleanup if the client was authorized. If not it will remove the existing connection, session etc. | |||||
// This allows to kill connections and sessions from known client IDs. | |||||
if (clientId != null) | |||||
{ | |||||
_connections.TryRemove(clientId, out _); | |||||
if (!_options.EnablePersistentSessions) | |||||
{ | |||||
await DeleteSessionAsync(clientId).ConfigureAwait(false); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); | |||||
private async Task CleanUpClient(string clientId, IMqttChannelAdapter channelAdapter, MqttClientDisconnectType disconnectType) | |||||
{ | |||||
if (clientId != null) | |||||
{ | |||||
_connections.TryRemove(clientId, out _); | |||||
if (clientWasAuthorized && clientId != null) | |||||
if (!_options.EnablePersistentSessions) | |||||
{ | { | ||||
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); | |||||
await DeleteSessionAsync(clientId).ConfigureAwait(false); | |||||
} | } | ||||
} | |||||
if (clientConnection != null) | |||||
{ | |||||
clientConnection.IsFinalized = true; | |||||
} | |||||
await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); | |||||
if (clientId != null) | |||||
{ | |||||
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); | |||||
} | } | ||||
} | } | ||||
@@ -345,7 +343,7 @@ namespace MQTTnet.Server | |||||
return context; | return context; | ||||
} | } | ||||
async Task<MqttClientConnection> CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) | |||||
async Task<MqttClientConnection> CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter, Func<Task> onStart, Func<MqttClientDisconnectType, Task> onStop) | |||||
{ | { | ||||
using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) | using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) | ||||
{ | { | ||||
@@ -354,13 +352,7 @@ namespace MQTTnet.Server | |||||
var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); | var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); | ||||
if (isConnectionPresent) | if (isConnectionPresent) | ||||
{ | { | ||||
await existingConnection.StopAsync(true); | |||||
// TODO: This fixes a race condition with unit test Same_Client_Id_Connect_Disconnect_Event_Order. | |||||
// It is not clear where the issue is coming from. The connected event is fired BEFORE the disconnected | |||||
// event. This is wrong. It seems that the finally block in HandleClientAsync must be finished before we | |||||
// can continue here. Maybe there is a better way to do this. | |||||
SpinWait.SpinUntil(() => existingConnection.IsFinalized, TimeSpan.FromSeconds(10)); | |||||
await existingConnection.StopAsync(true).ConfigureAwait(false); | |||||
} | } | ||||
if (isSessionPresent) | if (isSessionPresent) | ||||
@@ -383,7 +375,7 @@ namespace MQTTnet.Server | |||||
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | ||||
} | } | ||||
var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, _logger); | |||||
var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, onStart, onStop, _logger); | |||||
_connections[connection.ClientId] = connection; | _connections[connection.ClientId] = connection; | ||||
_sessions[session.ClientId] = session; | _sessions[session.ClientId] = session; | ||||