diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 82d3c76..2da3efc 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -230,6 +230,7 @@ namespace MQTTnet.Server { var disconnectType = MqttClientDisconnectType.NotClean; string clientId = null; + var clientWasConnected = true; try { @@ -246,6 +247,7 @@ namespace MQTTnet.Server if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { + clientWasConnected = false; // Send failure response here without preparing a session. The result for a successful connect // will be sent from the session itself. var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); @@ -269,21 +271,24 @@ namespace MQTTnet.Server } finally { - if (clientId != null) - { - _connections.TryRemove(clientId, out _); - - if (!_options.EnablePersistentSessions) + if (clientWasConnected) + { + if (clientId != null) { - await DeleteSessionAsync(clientId).ConfigureAwait(false); + _connections.TryRemove(clientId, out _); + + if (!_options.EnablePersistentSessions) + { + await DeleteSessionAsync(clientId).ConfigureAwait(false); + } } - } - await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); - if (clientId != null) - { - await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + if (clientId != null) + { + await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + } } } } diff --git a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs index 8111bd0..c4a5d27 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs @@ -6,6 +6,7 @@ using MQTTnet.Tests.Mockups; using MQTTnet.Client; using MQTTnet.Protocol; using MQTTnet.Server; +using System.Threading; namespace MQTTnet.Tests { @@ -55,10 +56,10 @@ namespace MQTTnet.Tests var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); - await Task.Delay(500); + await Task.Delay(1000); var clientStatus = await server.GetClientStatusAsync(); - + Assert.AreEqual(1, clientStatus.Count); Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index b2b3b70..c3d6df1 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -917,6 +917,109 @@ namespace MQTTnet.Tests } } + + private Dictionary _connected; + private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs) + { + if (_connected.ContainsKey(eventArgs.ClientId)) + { + eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; + return; + } + _connected[eventArgs.ClientId] = true; + eventArgs.ReasonCode = MqttConnectReasonCode.Success; + return; + } + + [TestMethod] + public async Task Same_Client_Id_Refuse_Connection() + { + using (var testEnvironment = new TestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + _connected = new Dictionary(); + var options = new MqttServerOptionsBuilder(); + options.WithConnectionValidator(e => ConnectionValidationHandler(e)); + var server = await testEnvironment.StartServerAsync(options); + + var events = new List(); + + server.ClientConnectedHandler = new MqttServerClientConnectedHandlerDelegate(_ => + { + lock (events) + { + events.Add("c"); + } + }); + + server.ClientDisconnectedHandler = new MqttServerClientDisconnectedHandlerDelegate(_ => + { + lock (events) + { + events.Add("d"); + } + }); + + var clientOptions = new MqttClientOptionsBuilder() + .WithClientId("same_id"); + + // c + var c1 = await testEnvironment.ConnectClientAsync(clientOptions); + + c1.UseDisconnectedHandler(_ => + { + lock (events) + { + events.Add("x"); + } + }); + + + c1.UseApplicationMessageReceivedHandler(_ => + { + lock (events) + { + events.Add("r"); + } + + }); + + c1.SubscribeAsync("topic").Wait(); + + await Task.Delay(500); + + c1.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + var flow = string.Join(string.Empty, events); + Assert.AreEqual("cr", flow); + + try + { + await testEnvironment.ConnectClientAsync(clientOptions); + Assert.Fail("same id connection is expected to fail"); + } + catch + { + //same id connection is expected to fail + } + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cr", flow); + + c1.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("crr", flow); + } + } + [TestMethod] public async Task Same_Client_Id_Connect_Disconnect_Event_Order() { @@ -956,17 +1059,40 @@ namespace MQTTnet.Tests // dc var c2 = await testEnvironment.ConnectClientAsync(clientOptions); + c2.UseApplicationMessageReceivedHandler(_ => + { + lock (events) + { + events.Add("r"); + } + + }); + c2.SubscribeAsync("topic").Wait(); + await Task.Delay(500); flow = string.Join(string.Empty, events); Assert.AreEqual("cdc", flow); + // r + c2.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdcr", flow); + + // nothing + + Assert.AreEqual(false, c1.IsConnected); await c1.DisconnectAsync(); + Assert.AreEqual (false, c1.IsConnected); await Task.Delay(500); // d + Assert.AreEqual(true, c2.IsConnected); await c2.DisconnectAsync(); await Task.Delay(500); @@ -974,7 +1100,7 @@ namespace MQTTnet.Tests await server.StopAsync(); flow = string.Join(string.Empty, events); - Assert.AreEqual("cdcd", flow); + Assert.AreEqual("cdcrd", flow); } }