diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index 362af39..fb7c471 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -142,15 +142,17 @@ namespace MQTTnet.Server Session.WillMessage = ConnectPacket.WillMessage; - Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger); + var cancellationToken = _cancellationToken.Token; - await SendAsync(_channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(_connectionValidatorContext)).ConfigureAwait(false); + Task.Run(() => SendPendingPacketsAsync(cancellationToken), cancellationToken).Forget(_logger); + + await SendAsync(_channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(_connectionValidatorContext), cancellationToken).ConfigureAwait(false); Session.IsCleanSession = false; - while (!_cancellationToken.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - var packet = await _channelAdapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationToken.Token).ConfigureAwait(false); + var packet = await _channelAdapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); if (packet == null) { // The client has closed the connection gracefully. @@ -167,7 +169,7 @@ namespace MQTTnet.Server if (packet is MqttPublishPacket publishPacket) { - await HandleIncomingPublishPacketAsync(publishPacket).ConfigureAwait(false); + await HandleIncomingPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false); continue; } @@ -179,25 +181,25 @@ namespace MQTTnet.Server ReasonCode = MqttPubCompReasonCode.Success }; - await SendAsync(pubCompPacket).ConfigureAwait(false); + await SendAsync(pubCompPacket, cancellationToken).ConfigureAwait(false); continue; } if (packet is MqttSubscribePacket subscribePacket) { - await HandleIncomingSubscribePacketAsync(subscribePacket).ConfigureAwait(false); + await HandleIncomingSubscribePacketAsync(subscribePacket, cancellationToken).ConfigureAwait(false); continue; } if (packet is MqttUnsubscribePacket unsubscribePacket) { - await HandleIncomingUnsubscribePacketAsync(unsubscribePacket).ConfigureAwait(false); + await HandleIncomingUnsubscribePacketAsync(unsubscribePacket, cancellationToken).ConfigureAwait(false); continue; } if (packet is MqttPingReqPacket) { - await SendAsync(new MqttPingRespPacket()).ConfigureAwait(false); + await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false); continue; } @@ -289,12 +291,12 @@ namespace MQTTnet.Server } } - async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) + async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); var subAckPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreateSubAckPacket(subscribePacket, subscribeResult); - await SendAsync(subAckPacket).ConfigureAwait(false); + await SendAsync(subAckPacket, cancellationToken).ConfigureAwait(false); if (subscribeResult.CloseConnection) { @@ -305,15 +307,15 @@ namespace MQTTnet.Server await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); } - async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) + async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var reasonCodes = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); var unsubAckPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreateUnsubAckPacket(unsubscribePacket, reasonCodes); - await SendAsync(unsubAckPacket).ConfigureAwait(false); + await SendAsync(unsubAckPacket, cancellationToken).ConfigureAwait(false); } - Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { Interlocked.Increment(ref _sentApplicationMessagesCount); @@ -325,11 +327,11 @@ namespace MQTTnet.Server } case MqttQualityOfServiceLevel.AtLeastOnce: { - return HandleIncomingPublishPacketWithQoS1Async(publishPacket); + return HandleIncomingPublishPacketWithQoS1Async(publishPacket, cancellationToken); } case MqttQualityOfServiceLevel.ExactlyOnce: { - return HandleIncomingPublishPacketWithQoS2Async(publishPacket); + return HandleIncomingPublishPacketWithQoS2Async(publishPacket, cancellationToken); } default: { @@ -347,16 +349,16 @@ namespace MQTTnet.Server return PlatformAbstractionLayer.CompletedTask; } - Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); _sessionsManager.DispatchApplicationMessage(applicationMessage, this); var pubAckPacket = _dataConverter.CreatePubAckPacket(publishPacket); - return SendAsync(pubAckPacket); + return SendAsync(pubAckPacket, cancellationToken); } - Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); _sessionsManager.DispatchApplicationMessage(applicationMessage, this); @@ -367,7 +369,7 @@ namespace MQTTnet.Server ReasonCode = MqttPubRecReasonCode.Success }; - return SendAsync(pubRecPacket); + return SendAsync(pubRecPacket, cancellationToken); } async Task SendPendingPacketsAsync(CancellationToken cancellationToken) @@ -426,12 +428,12 @@ namespace MQTTnet.Server if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) { - await SendAsync(publishPacket).ConfigureAwait(false); + await SendAsync(publishPacket, cancellationToken).ConfigureAwait(false); } else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) { var awaiter = _packetDispatcher.AddAwaiter(publishPacket.PacketIdentifier); - await SendAsync(publishPacket).ConfigureAwait(false); + await SendAsync(publishPacket, cancellationToken).ConfigureAwait(false); await awaiter.WaitOneAsync(_serverOptions.DefaultCommunicationTimeout).ConfigureAwait(false); } else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) @@ -439,10 +441,10 @@ namespace MQTTnet.Server using (var awaiter1 = _packetDispatcher.AddAwaiter(publishPacket.PacketIdentifier)) using (var awaiter2 = _packetDispatcher.AddAwaiter(publishPacket.PacketIdentifier)) { - await SendAsync(publishPacket).ConfigureAwait(false); + await SendAsync(publishPacket, cancellationToken).ConfigureAwait(false); await awaiter1.WaitOneAsync(_serverOptions.DefaultCommunicationTimeout).ConfigureAwait(false); - await SendAsync(new MqttPubRelPacket { PacketIdentifier = publishPacket.PacketIdentifier }).ConfigureAwait(false); + await SendAsync(new MqttPubRelPacket { PacketIdentifier = publishPacket.PacketIdentifier }, cancellationToken).ConfigureAwait(false); await awaiter2.WaitOneAsync(_serverOptions.DefaultCommunicationTimeout).ConfigureAwait(false); } } @@ -482,9 +484,9 @@ namespace MQTTnet.Server } } - async Task SendAsync(MqttBasePacket packet) + async Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false); + await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); Interlocked.Increment(ref _receivedPacketsCount);