From e6cfef5295482002671e34fdeec48913c09c0c18 Mon Sep 17 00:00:00 2001 From: JanEggers Date: Wed, 18 Jul 2018 21:13:02 +0200 Subject: [PATCH] fixed disconnect to be triggered just once --- Source/MQTTnet/Client/MqttClient.cs | 24 ++++++++++++++++----- Tests/MQTTnet.Core.Tests/MqttClientTests.cs | 13 ++++++----- Tests/MQTTnet.Core.Tests/MqttServerTests.cs | 6 +++--- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 7bc21de..86264e4 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -29,6 +29,7 @@ namespace MQTTnet.Client private Task _keepAliveMessageSenderTask; private IMqttChannelAdapter _adapter; private bool _cleanDisconnectInitiated; + private TaskCompletionSource _disconnectReason; public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger) { @@ -54,6 +55,7 @@ namespace MQTTnet.Client try { _cancellationTokenSource = new CancellationTokenSource(); + _disconnectReason = new TaskCompletionSource(); _options = options; _packetIdentifierProvider.Reset(); _packetDispatcher.Reset(); @@ -85,8 +87,10 @@ namespace MQTTnet.Client catch (Exception exception) { _logger.Error(exception, "Error while connecting with server."); - await DisconnectInternalAsync(null, exception).ConfigureAwait(false); - + if (_disconnectReason.TrySetException(exception)) + { + await DisconnectInternalAsync(null, exception).ConfigureAwait(false); + } throw; } } @@ -104,7 +108,10 @@ namespace MQTTnet.Client } finally { - await DisconnectInternalAsync(null, null).ConfigureAwait(false); + if (_disconnectReason.TrySetCanceled()) + { + await DisconnectInternalAsync(null, null).ConfigureAwait(false); + } } } @@ -355,7 +362,10 @@ namespace MQTTnet.Client _logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets."); } - await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); + if (_disconnectReason.TrySetException(exception)) + { + await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); + } } finally { @@ -398,7 +408,11 @@ namespace MQTTnet.Client } _packetDispatcher.Dispatch(exception); - await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); + + if (_disconnectReason.TrySetException(exception)) + { + await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); + } } finally { diff --git a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs index 1334a80..455cf7d 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs @@ -1,4 +1,5 @@ -using System.Net.Sockets; +using System; +using System.Net.Sockets; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; @@ -16,10 +17,10 @@ namespace MQTTnet.Core.Tests var factory = new MqttFactory(); var client = factory.CreateMqttClient(); - var exceptionIsCorrect = false; + Exception ex = null; client.Disconnected += (s, e) => { - exceptionIsCorrect = e.Exception is MqttCommunicationException c && c.InnerException is SocketException; + ex = e.Exception; }; try @@ -29,8 +30,10 @@ namespace MQTTnet.Core.Tests catch { } - - Assert.IsTrue(exceptionIsCorrect); + + Assert.IsNotNull(ex); + Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException)); + Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException)); } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 39259d0..117e6ef 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -244,12 +244,12 @@ namespace MQTTnet.Core.Tests .WithTcpServer("localhost") .Build(); - bool disconnectCalled = false; + var disconnectCalled = 0; await s.StartAsync(new MqttServerOptions()); var c1 = new MqttFactory().CreateMqttClient(); - c1.Disconnected += (sender, args) => disconnectCalled = true; + c1.Disconnected += (sender, args) => disconnectCalled++; await c1.ConnectAsync(clientOptions); @@ -259,7 +259,7 @@ namespace MQTTnet.Core.Tests await Task.Delay(100); - Assert.IsTrue(disconnectCalled); + Assert.AreEqual(1, disconnectCalled); } [TestMethod]