diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index c6c35b1..6d639e7 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -18,6 +18,8 @@ * [LowLevelMqttClient] Added low level MQTT client in order to provide more flexibility when using the MQTT protocol. This client requires detailed knowledge about the MQTT protocol. * [Client] Improve connection stability (thanks to @jltjohanlindqvist). * [Client] Support WithConnectionUri to configure client (thanks to @PMExtra). +* [Client] Support PublishAsync with QoS 1 and QoS 2 from within an ApplicationMessageReceivedHandler (#648, #587, thanks to @PSanetra). +* [Client] Fixed MqttCommunicationTimedOutExceptions, caused by a long running ApplicationMessageReceivedHandler, which blocked MQTT packets from being processed (#829, thanks to @PSanetra). * [ManagedClient] Added builder class for MqttClientUnsubscribeOptions (thanks to @dominikviererbe). * [ManagedClient] Added support for persisted sessions (thansk to @PMExtra). * [ManagedClient] Fixed a memory leak (thanks to @zawodskoj). diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 2546fc1..dfcff8a 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -34,6 +34,9 @@ namespace MQTTnet.Client private CancellationTokenSource _backgroundCancellationTokenSource; private Task _packetReceiverTask; private Task _keepAlivePacketsSenderTask; + private Task _publishPacketReceiverTask; + + private AsyncQueue _publishPacketReceiverQueue; private IMqttChannelAdapter _adapter; private bool _cleanDisconnectInitiated; @@ -88,6 +91,9 @@ namespace MQTTnet.Client await _adapter.ConnectAsync(options.CommunicationTimeout, combined.Token).ConfigureAwait(false); _logger.Verbose("Connection with server established."); + _publishPacketReceiverQueue = new AsyncQueue(); + _publishPacketReceiverTask = Task.Run(() => ProcessReceivedPublishPackets(backgroundCancellationToken), backgroundCancellationToken); + _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(backgroundCancellationToken), backgroundCancellationToken); authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false); @@ -230,6 +236,9 @@ namespace MQTTnet.Client _backgroundCancellationTokenSource?.Dispose(); _backgroundCancellationTokenSource = null; + _publishPacketReceiverQueue?.Dispose(); + _publishPacketReceiverQueue = null; + _adapter?.Dispose(); _adapter = null; } @@ -300,9 +309,12 @@ namespace MQTTnet.Client try { var receiverTask = WaitForTaskAsync(_packetReceiverTask, sender); + var publishPacketReceiverTask = WaitForTaskAsync(_publishPacketReceiverTask, sender); var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender); - await Task.WhenAll(receiverTask, keepAliveTask).ConfigureAwait(false); + await Task.WhenAll(receiverTask, publishPacketReceiverTask, keepAliveTask).ConfigureAwait(false); + + _publishPacketReceiverQueue.Dispose(); } catch (Exception e) { @@ -522,7 +534,7 @@ namespace MQTTnet.Client if (packet is MqttPublishPacket publishPacket) { - await TryProcessReceivedPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false); + EnqueueReceivedPublishPacket(publishPacket); } else if (packet is MqttPubRelPacket pubRelPacket) { @@ -584,47 +596,71 @@ namespace MQTTnet.Client } } - private async Task TryProcessReceivedPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private void EnqueueReceivedPublishPacket(MqttPublishPacket publishPacket) { try { - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) - { - await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); - } - else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + _publishPacketReceiverQueue.Enqueue(publishPacket); + } + catch (Exception exception) + { + _logger.Error(exception, "Error while enqueueing application message."); + } + } + + private async Task ProcessReceivedPublishPackets(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + try { - if (await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false)) + var publishPacketDequeueResult = await _publishPacketReceiverQueue.TryDequeueAsync(cancellationToken); + + if (!publishPacketDequeueResult.IsSuccess) + { + return; + } + + var publishPacket = publishPacketDequeueResult.Item; + + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) { - await SendAsync(new MqttPubAckPacket + await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); + } + else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + { + if (await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false)) { - PacketIdentifier = publishPacket.PacketIdentifier, - ReasonCode = MqttPubAckReasonCode.Success - }, cancellationToken).ConfigureAwait(false); + await SendAsync(new MqttPubAckPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubAckReasonCode.Success + }, cancellationToken).ConfigureAwait(false); + } } - } - else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) - { - if (await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false)) + else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) { - var pubRecPacket = new MqttPubRecPacket + if (await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false)) { - PacketIdentifier = publishPacket.PacketIdentifier, - ReasonCode = MqttPubRecReasonCode.Success - }; + var pubRecPacket = new MqttPubRecPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubRecReasonCode.Success + }; - await SendAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); + await SendAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); + } + } + else + { + throw new MqttProtocolViolationException("Received a not supported QoS level."); } } - else + catch (Exception exception) { - throw new MqttProtocolViolationException("Received a not supported QoS level."); + _logger.Error(exception, "Error while handling application message."); } } - catch (Exception exception) - { - _logger.Error(exception, "Error while handling application message."); - } } private async Task PublishAtMostOnce(MqttPublishPacket publishPacket, CancellationToken cancellationToken) diff --git a/Source/MQTTnet/Internal/AsyncBlockingQueue.cs b/Source/MQTTnet/Internal/AsyncQueue.cs similarity index 79% rename from Source/MQTTnet/Internal/AsyncBlockingQueue.cs rename to Source/MQTTnet/Internal/AsyncQueue.cs index 6cb80d2..43ad938 100644 --- a/Source/MQTTnet/Internal/AsyncBlockingQueue.cs +++ b/Source/MQTTnet/Internal/AsyncQueue.cs @@ -23,9 +23,15 @@ namespace MQTTnet.Internal { while (!cancellationToken.IsCancellationRequested) { - await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - - cancellationToken.ThrowIfCancellationRequested(); + try + { + await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + } + catch (OperationCanceledException) + { + return new AsyncQueueDequeueResult(false, default(TItem)); + } if (_queue.TryDequeue(out var item)) { diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index a163b81..d22cbe9 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -23,6 +23,8 @@ namespace MQTTnet.Server readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); readonly IMqttRetainedMessagesManager _retainedMessagesManager; + readonly Func _onStart; + readonly Func _onStop; readonly MqttClientKeepAliveMonitor _keepAliveMonitor; readonly MqttClientSessionsManager _sessionsManager; @@ -34,7 +36,7 @@ namespace MQTTnet.Server readonly string _endpoint; readonly DateTime _connectedTimestamp; - Task _packageReceiverTask; + volatile Task _packageReceiverTask; DateTime _lastPacketReceivedTimestamp; DateTime _lastNonKeepAlivePacketReceivedTimestamp; @@ -43,7 +45,7 @@ namespace MQTTnet.Server long _receivedApplicationMessagesCount; long _sentApplicationMessagesCount; - bool _isTakeover; + volatile bool _isTakeover; public MqttClientConnection( MqttConnectPacket connectPacket, @@ -52,12 +54,16 @@ namespace MQTTnet.Server IMqttServerOptions serverOptions, MqttClientSessionsManager sessionsManager, IMqttRetainedMessagesManager retainedMessagesManager, + Func onStart, + Func onStop, IMqttNetLogger logger) { Session = session ?? throw new ArgumentNullException(nameof(session)); _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); + _onStart = onStart ?? throw new ArgumentNullException(nameof(onStart)); + _onStop = onStop ?? throw new ArgumentNullException(nameof(onStop)); _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; @@ -80,15 +86,13 @@ namespace MQTTnet.Server public MqttClientSession Session { get; } - public bool IsFinalized { get; set; } - public Task StopAsync(bool isTakeover = false) { _isTakeover = isTakeover; + var task = _packageReceiverTask; StopInternal(); - var task = _packageReceiverTask; if (task != null) { return task; @@ -127,17 +131,18 @@ namespace MQTTnet.Server _cancellationToken.Dispose(); } - public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) + public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) { _packageReceiverTask = RunInternalAsync(connectionValidatorContext); return _packageReceiverTask; } - async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) + async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) { var disconnectType = MqttClientDisconnectType.NotClean; try { + await _onStart(); _logger.Info("Client '{0}': Session started.", ClientId); _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; @@ -241,6 +246,11 @@ namespace MQTTnet.Server } finally { + if (_isTakeover) + { + disconnectType = MqttClientDisconnectType.Takeover; + } + if (Session.WillMessage != null) { _sessionsManager.DispatchApplicationMessage(Session.WillMessage, this); @@ -255,14 +265,16 @@ namespace MQTTnet.Server _logger.Info("Client '{0}': Connection stopped.", ClientId); _packageReceiverTask = null; - } - if (_isTakeover) - { - return MqttClientDisconnectType.Takeover; + try + { + await _onStop(disconnectType); + } + catch (Exception e) + { + _logger.Error(e, "client '{0}': Error while cleaning up", ClientId); + } } - - return disconnectType; } void StopInternal() diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 1e2bcef..f98eb05 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -185,6 +185,12 @@ namespace MQTTnet.Server } var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); + + if (!dequeueResult.IsSuccess) + { + return; + } + var queuedApplicationMessage = dequeueResult.Item; var sender = queuedApplicationMessage.Sender; @@ -235,12 +241,9 @@ namespace MQTTnet.Server async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { - var disconnectType = MqttClientDisconnectType.NotClean; string clientId = null; - var clientWasAuthorized = false; MqttConnectPacket connectPacket; - MqttClientConnection clientConnection = null; try { try @@ -271,13 +274,17 @@ namespace MQTTnet.Server return; } - clientWasAuthorized = true; clientId = connectPacket.ClientId; - clientConnection = await CreateClientConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); - await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); + var connection = await CreateClientConnectionAsync( + connectPacket, + connectionValidatorContext, + channelAdapter, + async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false), + async disconnectType => await CleanUpClient(clientId, channelAdapter, disconnectType) + ).ConfigureAwait(false); - disconnectType = await clientConnection.RunAsync(connectionValidatorContext).ConfigureAwait(false); + await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -286,34 +293,25 @@ namespace MQTTnet.Server { _logger.Error(exception, exception.Message); } - finally - { - if (clientWasAuthorized && disconnectType != MqttClientDisconnectType.Takeover) - { - // Only cleanup if the client was authorized. If not it will remove the existing connection, session etc. - // This allows to kill connections and sessions from known client IDs. - if (clientId != null) - { - _connections.TryRemove(clientId, out _); - - if (!_options.EnablePersistentSessions) - { - await DeleteSessionAsync(clientId).ConfigureAwait(false); - } - } - } + } - await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + private async Task CleanUpClient(string clientId, IMqttChannelAdapter channelAdapter, MqttClientDisconnectType disconnectType) + { + if (clientId != null) + { + _connections.TryRemove(clientId, out _); - if (clientWasAuthorized && clientId != null) + if (!_options.EnablePersistentSessions) { - await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + await DeleteSessionAsync(clientId).ConfigureAwait(false); } + } - if (clientConnection != null) - { - clientConnection.IsFinalized = true; - } + await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + + if (clientId != null) + { + await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); } } @@ -345,7 +343,7 @@ namespace MQTTnet.Server return context; } - async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) + async Task CreateClientConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter, Func onStart, Func onStop) { using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false)) { @@ -354,13 +352,7 @@ namespace MQTTnet.Server var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection); if (isConnectionPresent) { - await existingConnection.StopAsync(true); - - // TODO: This fixes a race condition with unit test Same_Client_Id_Connect_Disconnect_Event_Order. - // It is not clear where the issue is coming from. The connected event is fired BEFORE the disconnected - // event. This is wrong. It seems that the finally block in HandleClientAsync must be finished before we - // can continue here. Maybe there is a better way to do this. - SpinWait.SpinUntil(() => existingConnection.IsFinalized, TimeSpan.FromSeconds(10)); + await existingConnection.StopAsync(true).ConfigureAwait(false); } if (isSessionPresent) @@ -383,7 +375,7 @@ namespace MQTTnet.Server _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } - var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, _logger); + var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, onStart, onStop, _logger); _connections[connection.ClientId] = connection; _sessions[session.ClientId] = session; diff --git a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs index 2b86955..1a81512 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs @@ -12,6 +12,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Net.Sockets; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -434,6 +435,49 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + await testEnvironment.StartServerAsync(); + + const string client1Topic = "client1/topic"; + const string client2Topic = "client2/topic"; + const string expectedClient2Message = "hello client2"; + + var client1 = await testEnvironment.ConnectClientAsync(); + client1.UseApplicationMessageReceivedHandler(async c => + { + await client1.PublishAsync(client2Topic, expectedClient2Message, MqttQualityOfServiceLevel.AtLeastOnce); + }); + + await client1.SubscribeAsync(client1Topic, MqttQualityOfServiceLevel.AtLeastOnce); + + var client2 = await testEnvironment.ConnectClientAsync(); + + var client2TopicResults = new List(); + + client2.UseApplicationMessageReceivedHandler(c => + { + client2TopicResults.Add(Encoding.UTF8.GetString(c.ApplicationMessage.Payload)); + }); + + await client2.SubscribeAsync(client2Topic); + + var client3 = await testEnvironment.ConnectClientAsync(); + var message = new MqttApplicationMessageBuilder().WithTopic(client1Topic).Build(); + await client3.PublishAsync(message); + await client3.PublishAsync(message); + + await Task.Delay(500); + + Assert.AreEqual(2, client2TopicResults.Count); + Assert.AreEqual(expectedClient2Message, client2TopicResults[0]); + Assert.AreEqual(expectedClient2Message, client2TopicResults[1]); + } + } + [TestMethod] public async Task Subscribe_In_Callback_Events() { @@ -565,7 +609,7 @@ namespace MQTTnet.Tests for (var i = 0; i < 98; i++) { - Assert.IsFalse(clients[i].IsConnected); + Assert.IsFalse(clients[i].IsConnected, $"clients[{i}] is not connected"); } Assert.IsTrue(clients[99].IsConnected);