Browse Source

fixed disconnect to be triggered just once

release/3.x.x
JanEggers 6 years ago
parent
commit
e6cfef5295
3 changed files with 30 additions and 13 deletions
  1. +19
    -5
      Source/MQTTnet/Client/MqttClient.cs
  2. +8
    -5
      Tests/MQTTnet.Core.Tests/MqttClientTests.cs
  3. +3
    -3
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs

+ 19
- 5
Source/MQTTnet/Client/MqttClient.cs View File

@@ -29,6 +29,7 @@ namespace MQTTnet.Client
private Task _keepAliveMessageSenderTask; private Task _keepAliveMessageSenderTask;
private IMqttChannelAdapter _adapter; private IMqttChannelAdapter _adapter;
private bool _cleanDisconnectInitiated; private bool _cleanDisconnectInitiated;
private TaskCompletionSource<bool> _disconnectReason;


public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger) public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger)
{ {
@@ -54,6 +55,7 @@ namespace MQTTnet.Client
try try
{ {
_cancellationTokenSource = new CancellationTokenSource(); _cancellationTokenSource = new CancellationTokenSource();
_disconnectReason = new TaskCompletionSource<bool>();
_options = options; _options = options;
_packetIdentifierProvider.Reset(); _packetIdentifierProvider.Reset();
_packetDispatcher.Reset(); _packetDispatcher.Reset();
@@ -85,8 +87,10 @@ namespace MQTTnet.Client
catch (Exception exception) catch (Exception exception)
{ {
_logger.Error(exception, "Error while connecting with server."); _logger.Error(exception, "Error while connecting with server.");
await DisconnectInternalAsync(null, exception).ConfigureAwait(false);

if (_disconnectReason.TrySetException(exception))
{
await DisconnectInternalAsync(null, exception).ConfigureAwait(false);
}
throw; throw;
} }
} }
@@ -104,7 +108,10 @@ namespace MQTTnet.Client
} }
finally finally
{ {
await DisconnectInternalAsync(null, null).ConfigureAwait(false);
if (_disconnectReason.TrySetCanceled())
{
await DisconnectInternalAsync(null, null).ConfigureAwait(false);
}
} }
} }


@@ -355,7 +362,10 @@ namespace MQTTnet.Client
_logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets."); _logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets.");
} }


await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false);
if (_disconnectReason.TrySetException(exception))
{
await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false);
}
} }
finally finally
{ {
@@ -398,7 +408,11 @@ namespace MQTTnet.Client
} }


_packetDispatcher.Dispatch(exception); _packetDispatcher.Dispatch(exception);
await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false);

if (_disconnectReason.TrySetException(exception))
{
await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false);
}
} }
finally finally
{ {


+ 8
- 5
Tests/MQTTnet.Core.Tests/MqttClientTests.cs View File

@@ -1,4 +1,5 @@
using System.Net.Sockets;
using System;
using System.Net.Sockets;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client; using MQTTnet.Client;
@@ -16,10 +17,10 @@ namespace MQTTnet.Core.Tests
var factory = new MqttFactory(); var factory = new MqttFactory();
var client = factory.CreateMqttClient(); var client = factory.CreateMqttClient();


var exceptionIsCorrect = false;
Exception ex = null;
client.Disconnected += (s, e) => client.Disconnected += (s, e) =>
{ {
exceptionIsCorrect = e.Exception is MqttCommunicationException c && c.InnerException is SocketException;
ex = e.Exception;
}; };


try try
@@ -29,8 +30,10 @@ namespace MQTTnet.Core.Tests
catch catch
{ {
} }
Assert.IsTrue(exceptionIsCorrect);

Assert.IsNotNull(ex);
Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException));
Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException));
} }
} }
} }

+ 3
- 3
Tests/MQTTnet.Core.Tests/MqttServerTests.cs View File

@@ -244,12 +244,12 @@ namespace MQTTnet.Core.Tests
.WithTcpServer("localhost") .WithTcpServer("localhost")
.Build(); .Build();


bool disconnectCalled = false;
var disconnectCalled = 0;


await s.StartAsync(new MqttServerOptions()); await s.StartAsync(new MqttServerOptions());


var c1 = new MqttFactory().CreateMqttClient(); var c1 = new MqttFactory().CreateMqttClient();
c1.Disconnected += (sender, args) => disconnectCalled = true;
c1.Disconnected += (sender, args) => disconnectCalled++;


await c1.ConnectAsync(clientOptions); await c1.ConnectAsync(clientOptions);


@@ -259,7 +259,7 @@ namespace MQTTnet.Core.Tests


await Task.Delay(100); await Task.Delay(100);


Assert.IsTrue(disconnectCalled);
Assert.AreEqual(1, disconnectCalled);
} }


[TestMethod] [TestMethod]


Loading…
Cancel
Save