diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index e71d1a8..ed36b6a 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -118,13 +118,13 @@ namespace MQTTnet.Server _cancellationToken.Dispose(); } - public Task RunAsync() + public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) { - _packageReceiverTask = RunInternalAsync(); + _packageReceiverTask = RunInternalAsync(connectionValidatorContext); return _packageReceiverTask; } - private async Task RunInternalAsync() + private async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) { var disconnectType = MqttClientDisconnectType.NotClean; try @@ -142,12 +142,8 @@ namespace MQTTnet.Server _keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token); await SendAsync( - new MqttConnAckPacket - { - ReturnCode = MqttConnectReturnCode.ConnectionAccepted, - ReasonCode = MqttConnectReasonCode.Success, - IsSessionPresent = !Session.IsCleanSession - }).ConfigureAwait(false); + _channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext) + ).ConfigureAwait(false); Session.IsCleanSession = false; diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index db70e95..82d3c76 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -240,10 +240,10 @@ namespace MQTTnet.Server return; } - clientId = connectPacket.ClientId; - var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); + clientId = connectPacket.ClientId; + if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { // Send failure response here without preparing a session. The result for a successful connect @@ -258,7 +258,7 @@ namespace MQTTnet.Server await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); - disconnectType = await connection.RunAsync().ConfigureAwait(false); + disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); } catch (OperationCanceledException) { diff --git a/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs b/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs index 3b813c3..c4c7634 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; @@ -57,6 +58,65 @@ namespace MQTTnet.Tests.MQTTv5 Assert.AreEqual(2, receivedMessage.UserProperties.Count); } } + [TestMethod] + public async Task Connect_With_AssignedClientId() + { + using (var testEnvironment = new TestEnvironment()) + { + string serverConnectedClientId = null; + string serverDisconnectedClientId = null; + string clientAssignedClientId = null; + + // Arrange server + var disconnectedMre = new ManualResetEventSlim(); + var serverOptions = new MqttServerOptionsBuilder() + .WithConnectionValidator((context) => + { + if (string.IsNullOrEmpty(context.ClientId)) + { + context.AssignedClientIdentifier = "test123"; + context.ReasonCode = MqttConnectReasonCode.Success; + } + }); + await testEnvironment.StartServerAsync(serverOptions); + testEnvironment.Server.UseClientConnectedHandler((args) => + { + serverConnectedClientId = args.ClientId; + }); + testEnvironment.Server.UseClientDisconnectedHandler((args) => + { + serverDisconnectedClientId = args.ClientId; + disconnectedMre.Set(); + }); + + // Arrange client + var client = testEnvironment.CreateClient(); + client.UseConnectedHandler((args) => + { + clientAssignedClientId = args.AuthenticateResult.AssignedClientIdentifier; + }); + + // Act + await client.ConnectAsync(new MqttClientOptionsBuilder() + .WithTcpServer("127.0.0.1", testEnvironment.ServerPort) + .WithProtocolVersion(MqttProtocolVersion.V500) + .WithClientId(null) + .Build()); + await client.DisconnectAsync(); + + // Wait for ClientDisconnectedHandler to trigger + disconnectedMre.Wait(500); + + // Assert + Assert.IsNotNull(serverConnectedClientId); + Assert.IsNotNull(serverDisconnectedClientId); + Assert.IsNotNull(clientAssignedClientId); + Assert.AreEqual("test123", serverConnectedClientId); + Assert.AreEqual("test123", serverDisconnectedClientId); + Assert.AreEqual("test123", clientAssignedClientId); + + } + } [TestMethod] public async Task Connect()