@@ -561,12 +561,16 @@ namespace MQTTnet.Client | |||||
{ | { | ||||
await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false); | await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false); | ||||
} | } | ||||
else if (packet is MqttDisconnectPacket) | |||||
else if (packet is MqttDisconnectPacket disc) | |||||
{ | { | ||||
// Also dispatch disconnect to waiting threads to generate a proper exception. | // Also dispatch disconnect to waiting threads to generate a proper exception. | ||||
_packetDispatcher.Dispatch(packet); | _packetDispatcher.Dispatch(packet); | ||||
await DisconnectAsync(null, cancellationToken).ConfigureAwait(false); | |||||
await DisconnectAsync(new MqttClientDisconnectOptions() | |||||
{ | |||||
// todo conversion | |||||
ReasonCode = disc.ReasonCode | |||||
}, cancellationToken).ConfigureAwait(false); | |||||
} | } | ||||
else if (packet is MqttAuthPacket authPacket) | else if (packet is MqttAuthPacket authPacket) | ||||
{ | { | ||||
@@ -257,6 +257,15 @@ namespace MQTTnet.Server | |||||
Session.WillMessage = null; | Session.WillMessage = null; | ||||
} | } | ||||
if (_isTakeover) | |||||
{ | |||||
// dont use SendAsync here _cancellationToken is already cancelled | |||||
await _channelAdapter.SendPacketAsync(new MqttDisconnectPacket() | |||||
{ | |||||
ReasonCode = MqttDisconnectReasonCode.SessionTakenOver | |||||
}, TimeSpan.Zero, CancellationToken.None).ConfigureAwait(false); | |||||
} | |||||
_packetDispatcher.Reset(); | _packetDispatcher.Reset(); | ||||
_channelAdapter.ReadingPacketStartedCallback = null; | _channelAdapter.ReadingPacketStartedCallback = null; | ||||
@@ -316,7 +316,11 @@ namespace MQTTnet.Server | |||||
{ | { | ||||
if (clientId != null) | if (clientId != null) | ||||
{ | { | ||||
_connections.TryRemove(clientId, out _); | |||||
// in case it is a takeover _connections already contains the new connection | |||||
if (disconnectType != MqttClientDisconnectType.Takeover) | |||||
{ | |||||
_connections.TryRemove(clientId, out _); | |||||
} | |||||
if (!_options.EnablePersistentSessions) | if (!_options.EnablePersistentSessions) | ||||
{ | { | ||||
@@ -364,38 +368,39 @@ namespace MQTTnet.Server | |||||
{ | { | ||||
using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) | using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) | ||||
{ | { | ||||
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session); | |||||
var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); | |||||
if (isConnectionPresent) | |||||
var session = _sessions.AddOrUpdate(connectPacket.ClientId, key => | |||||
{ | { | ||||
await existingConnection.StopAsync(true).ConfigureAwait(false); | |||||
} | |||||
if (isSessionPresent) | |||||
_logger.Verbose("Created a new session for client '{0}'.", key); | |||||
return new MqttClientSession(key, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _rootLogger); | |||||
}, (key, existingSession) => | |||||
{ | { | ||||
if (connectPacket.CleanSession) | if (connectPacket.CleanSession) | ||||
{ | { | ||||
session = null; | |||||
_logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId); | _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId); | ||||
return new MqttClientSession(key, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _rootLogger); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
_logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId); | _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId); | ||||
return existingSession; | |||||
} | } | ||||
} | |||||
if (session == null) | |||||
{ | |||||
session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _rootLogger); | |||||
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | |||||
} | |||||
}); | |||||
var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, onStart, onStop, _rootLogger); | var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, onStart, onStop, _rootLogger); | ||||
MqttClientConnection existingConnection = null; | |||||
_connections.AddOrUpdate(connectPacket.ClientId, key => | |||||
{ | |||||
return connection; | |||||
}, (key, tempexistingConnection) => | |||||
{ | |||||
existingConnection = tempexistingConnection; | |||||
return connection; | |||||
}); | |||||
_connections[connection.ClientId] = connection; | |||||
_sessions[session.ClientId] = session; | |||||
if (existingConnection != null) | |||||
{ | |||||
await existingConnection.StopAsync(true).ConfigureAwait(false); | |||||
} | |||||
return connection; | return connection; | ||||
} | } | ||||
@@ -1,5 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using MQTTnet.Client; | using MQTTnet.Client; | ||||
using MQTTnet.Client.Options; | |||||
using MQTTnet.Client.Subscribing; | using MQTTnet.Client.Subscribing; | ||||
using MQTTnet.Server; | using MQTTnet.Server; | ||||
using MQTTnet.Tests.Mockups; | using MQTTnet.Tests.Mockups; | ||||
@@ -84,5 +85,40 @@ namespace MQTTnet.Tests | |||||
Assert.AreEqual(true, session.Items["can_subscribe_x"]); | Assert.AreEqual(true, session.Items["can_subscribe_x"]); | ||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public async Task Manage_Session_MaxParallel() | |||||
{ | |||||
using (var testEnvironment = new TestEnvironment(TestContext)) | |||||
{ | |||||
testEnvironment.IgnoreClientLogErrors = true; | |||||
var serverOptions = new MqttServerOptionsBuilder(); | |||||
await testEnvironment.StartServerAsync(serverOptions); | |||||
var options = new MqttClientOptionsBuilder() | |||||
.WithClientId("1") | |||||
; | |||||
var clients = await Task.WhenAll(Enumerable.Range(0, 10) | |||||
.Select(i => TryConnect(testEnvironment, options))); | |||||
var connectedClients = clients.Where(c => c?.IsConnected ?? false).ToList(); | |||||
Assert.AreEqual(1, connectedClients.Count); | |||||
} | |||||
} | |||||
private async Task<IMqttClient> TryConnect(TestEnvironment testEnvironment, MqttClientOptionsBuilder options) | |||||
{ | |||||
try | |||||
{ | |||||
return await testEnvironment.ConnectClientAsync(options); | |||||
} | |||||
catch (System.Exception) | |||||
{ | |||||
return null; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |