diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 8a045be..51a3612 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -561,12 +561,16 @@ namespace MQTTnet.Client { await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false); } - else if (packet is MqttDisconnectPacket) + else if (packet is MqttDisconnectPacket disc) { // Also dispatch disconnect to waiting threads to generate a proper exception. _packetDispatcher.Dispatch(packet); - await DisconnectAsync(null, cancellationToken).ConfigureAwait(false); + await DisconnectAsync(new MqttClientDisconnectOptions() + { + // todo conversion + ReasonCode = disc.ReasonCode + }, cancellationToken).ConfigureAwait(false); } else if (packet is MqttAuthPacket authPacket) { diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index bf51269..57fb63e 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -257,6 +257,15 @@ namespace MQTTnet.Server Session.WillMessage = null; } + if (_isTakeover) + { + // dont use SendAsync here _cancellationToken is already cancelled + await _channelAdapter.SendPacketAsync(new MqttDisconnectPacket() + { + ReasonCode = MqttDisconnectReasonCode.SessionTakenOver + }, TimeSpan.Zero, CancellationToken.None).ConfigureAwait(false); + } + _packetDispatcher.Reset(); _channelAdapter.ReadingPacketStartedCallback = null; diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 48f6980..6935bcb 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -316,7 +316,11 @@ namespace MQTTnet.Server { if (clientId != null) { - _connections.TryRemove(clientId, out _); + // in case it is a takeover _connections already contains the new connection + if (disconnectType != MqttClientDisconnectType.Takeover) + { + _connections.TryRemove(clientId, out _); + } if (!_options.EnablePersistentSessions) { @@ -364,38 +368,39 @@ namespace MQTTnet.Server { using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) { - var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session); - - var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); - if (isConnectionPresent) + var session = _sessions.AddOrUpdate(connectPacket.ClientId, key => { - await existingConnection.StopAsync(true).ConfigureAwait(false); - } - - if (isSessionPresent) + _logger.Verbose("Created a new session for client '{0}'.", key); + return new MqttClientSession(key, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _rootLogger); + }, (key, existingSession) => { if (connectPacket.CleanSession) { - session = null; - _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId); + return new MqttClientSession(key, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _rootLogger); } else { _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId); + return existingSession; } - } - - if (session == null) - { - session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _rootLogger); - _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); - } + }); var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, onStart, onStop, _rootLogger); + MqttClientConnection existingConnection = null; + _connections.AddOrUpdate(connectPacket.ClientId, key => + { + return connection; + }, (key, tempexistingConnection) => + { + existingConnection = tempexistingConnection; + return connection; + }); - _connections[connection.ClientId] = connection; - _sessions[session.ClientId] = session; + if (existingConnection != null) + { + await existingConnection.StopAsync(true).ConfigureAwait(false); + } return connection; } diff --git a/Tests/MQTTnet.Core.Tests/Session_Tests.cs b/Tests/MQTTnet.Core.Tests/Session_Tests.cs index 973fbf3..1f4c418 100644 --- a/Tests/MQTTnet.Core.Tests/Session_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Session_Tests.cs @@ -1,5 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; +using MQTTnet.Client.Options; using MQTTnet.Client.Subscribing; using MQTTnet.Server; using MQTTnet.Tests.Mockups; @@ -84,5 +85,40 @@ namespace MQTTnet.Tests Assert.AreEqual(true, session.Items["can_subscribe_x"]); } } + + + [TestMethod] + public async Task Manage_Session_MaxParallel() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + testEnvironment.IgnoreClientLogErrors = true; + var serverOptions = new MqttServerOptionsBuilder(); + await testEnvironment.StartServerAsync(serverOptions); + + var options = new MqttClientOptionsBuilder() + .WithClientId("1") + ; + + var clients = await Task.WhenAll(Enumerable.Range(0, 10) + .Select(i => TryConnect(testEnvironment, options))); + + var connectedClients = clients.Where(c => c?.IsConnected ?? false).ToList(); + + Assert.AreEqual(1, connectedClients.Count); + } + } + + private async Task TryConnect(TestEnvironment testEnvironment, MqttClientOptionsBuilder options) + { + try + { + return await testEnvironment.ConnectClientAsync(options); + } + catch (System.Exception) + { + return null; + } + } } }