diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index ccedb5a..0fa074c 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -25,10 +25,11 @@ namespace MQTTnet.Client private IMqttClientOptions _options; private CancellationTokenSource _cancellationTokenSource; - private Task _packetReceiverTask; - private Task _keepAliveMessageSenderTask; + internal Task _packetReceiverTask; + internal Task _keepAliveMessageSenderTask; private IMqttChannelAdapter _adapter; private bool _cleanDisconnectInitiated; + private TaskCompletionSource _disconnectReason; public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger) { @@ -54,6 +55,7 @@ namespace MQTTnet.Client try { _cancellationTokenSource = new CancellationTokenSource(); + _disconnectReason = new TaskCompletionSource(); _options = options; _packetIdentifierProvider.Reset(); _packetDispatcher.Reset(); @@ -85,8 +87,10 @@ namespace MQTTnet.Client catch (Exception exception) { _logger.Error(exception, "Error while connecting with server."); - await DisconnectInternalAsync(null, exception).ConfigureAwait(false); - + if (_disconnectReason.TrySetException(exception)) + { + await DisconnectInternalAsync(null, exception).ConfigureAwait(false); + } throw; } } @@ -104,7 +108,10 @@ namespace MQTTnet.Client } finally { - await DisconnectInternalAsync(null, null).ConfigureAwait(false); + if (_disconnectReason.TrySetCanceled()) + { + await DisconnectInternalAsync(null, null).ConfigureAwait(false); + } } } @@ -352,7 +359,10 @@ namespace MQTTnet.Client _logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets."); } - await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); + if (_disconnectReason.TrySetException(exception)) + { + await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); + } } finally { @@ -395,7 +405,11 @@ namespace MQTTnet.Client } _packetDispatcher.Dispatch(exception); - await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); + + if (_disconnectReason.TrySetException(exception)) + { + await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); + } } finally { diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index d7c55a7..bc10d65 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -87,8 +87,14 @@ namespace MQTTnet.Implementations public void Dispose() { - TryDispose(_stream, () => _stream = null); - TryDispose(_socket, () => _socket = null); + Cleanup(ref _stream, (s) => s.Dispose()); + Cleanup(ref _socket, (s) => { + if (s.Connected) + { + s.Shutdown(SocketShutdown.Both); + } + s.Dispose(); + }); } private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) @@ -157,11 +163,16 @@ namespace MQTTnet.Implementations } } - private static void TryDispose(IDisposable disposable, Action afterDispose) + private static void Cleanup(ref T item, Action handler) where T : class { + var temp = item; + item = null; try { - disposable?.Dispose(); + if (temp != null) + { + handler(temp); + } } catch (ObjectDisposedException) { @@ -169,10 +180,6 @@ namespace MQTTnet.Implementations catch (NullReferenceException) { } - finally - { - afterDispose(); - } } } } diff --git a/Source/MQTTnet/Properties/AssemblyInfo.cs b/Source/MQTTnet/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..8259cc1 --- /dev/null +++ b/Source/MQTTnet/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly:InternalsVisibleTo("MQTTnet.Core.Tests")] diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index fa90482..ffc9fbd 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -29,6 +29,8 @@ namespace MQTTnet.Server private MqttApplicationMessage _willMessage; private bool _wasCleanDisconnect; private IMqttChannelAdapter _adapter; + private Task _workerTask; + private IDisposable _cleanupHandle; public MqttClientSession( string clientId, @@ -65,7 +67,13 @@ namespace MQTTnet.Server status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; } - public async Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) + public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) + { + _workerTask = RunInternalAsync(connectPacket, adapter); + return _workerTask; + } + + private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) { if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); if (adapter == null) throw new ArgumentNullException(nameof(adapter)); @@ -77,6 +85,10 @@ namespace MQTTnet.Server adapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted; _cancellationTokenSource = new CancellationTokenSource(); + + //woraround for https://github.com/dotnet/corefx/issues/24430 + _cleanupHandle = _cancellationTokenSource.Token.Register(() => Cleanup()); + //endworkaround _wasCleanDisconnect = false; _willMessage = connectPacket.WillMessage; @@ -114,24 +126,49 @@ namespace MQTTnet.Server _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); } - Stop(MqttClientDisconnectType.NotClean); + Stop(MqttClientDisconnectType.NotClean, true); } finally { - if (_adapter != null) + await Cleanup().ConfigureAwait(false); + + _cleanupHandle?.Dispose(); + _cleanupHandle = null; + + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + } + } + + private async Task Cleanup() + { + try + { + var adapter = _adapter; + if (adapter == null) { - _adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; - _adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; + return; } _adapter = null; - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; + adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; + adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; + await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); + adapter.Dispose(); + } + catch (Exception exception) + { + _logger.Error(exception, exception.Message); } } public void Stop(MqttClientDisconnectType type) + { + Stop(type, false); + } + + private void Stop(MqttClientDisconnectType type, bool isInsideSession) { try { @@ -151,6 +188,11 @@ namespace MQTTnet.Server } _willMessage = null; + + if (!isInsideSession) + { + _workerTask?.GetAwaiter().GetResult(); + } } finally { @@ -298,18 +340,18 @@ namespace MQTTnet.Server if (packet is MqttDisconnectPacket) { - Stop(MqttClientDisconnectType.Clean); + Stop(MqttClientDisconnectType.Clean, true); return; } if (packet is MqttConnectPacket) { - Stop(MqttClientDisconnectType.NotClean); + Stop(MqttClientDisconnectType.NotClean, true); return; } _logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); - Stop(MqttClientDisconnectType.NotClean); + Stop(MqttClientDisconnectType.NotClean, true); } private void EnqueueSubscribedRetainedMessages(ICollection topicFilters) @@ -328,7 +370,7 @@ namespace MQTTnet.Server if (subscribeResult.CloseConnection) { - Stop(MqttClientDisconnectType.NotClean); + Stop(MqttClientDisconnectType.NotClean, true); return; } diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index c139151..d8f4f14 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -264,16 +264,6 @@ namespace MQTTnet.Server } finally { - try - { - await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); - clientAdapter.Dispose(); - } - catch (Exception exception) - { - _logger.Error(exception, exception.Message); - } - if (!_options.EnablePersistentSessions) { DeleteSession(clientId); diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index dd732f1..6d53291 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -106,14 +106,14 @@ namespace MQTTnet.Server _cancellationTokenSource.Cancel(false); + _clientSessionsManager.Stop(); + foreach (var adapter in _adapters) { adapter.ClientAccepted -= OnClientAccepted; await adapter.StopAsync().ConfigureAwait(false); } - _clientSessionsManager.Stop(); - _logger.Info("Stopped."); Stopped?.Invoke(this, EventArgs.Empty); } diff --git a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs index f629f8f..be899c3 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs @@ -1,10 +1,13 @@ -using System; +using System; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Adapter; using MQTTnet.Client; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; +using MQTTnet.Packets; using MQTTnet.Implementations; using MQTTnet.Server; @@ -19,10 +22,10 @@ namespace MQTTnet.Core.Tests var factory = new MqttFactory(); var client = factory.CreateMqttClient(); - var exceptionIsCorrect = false; + Exception ex = null; client.Disconnected += (s, e) => { - exceptionIsCorrect = e.Exception is MqttCommunicationException c && c.InnerException is SocketException; + ex = e.Exception; }; try @@ -32,8 +35,42 @@ namespace MQTTnet.Core.Tests catch { } + + Assert.IsNotNull(ex); + Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException)); + Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException)); + } + + [TestMethod] + public async Task ClientCleanupOnAuthentificationFails() + { + var channel = new TestMqttCommunicationAdapter(); + var channel2 = new TestMqttCommunicationAdapter(); + channel.Partner = channel2; + channel2.Partner = channel; + + Task.Run(async () => { + var connect = await channel2.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None); + await channel2.SendPacketAsync(new MqttConnAckPacket() { ConnectReturnCode = Protocol.MqttConnectReturnCode.ConnectionRefusedNotAuthorized }, CancellationToken.None); + }); + + - Assert.IsTrue(exceptionIsCorrect); + var fake = new TestMqttCommunicationAdapterFactory(channel); + + var client = new MqttClient(fake, new MqttNetLogger()); + + try + { + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("any-server").Build()); + } + catch (Exception ex) + { + Assert.IsInstanceOfType(ex, typeof(MqttConnectingFailedException)); + } + + Assert.IsTrue(client._packetReceiverTask == null || client._packetReceiverTask.IsCompleted, "receive loop not completed"); + Assert.IsTrue(client._keepAliveMessageSenderTask == null || client._keepAliveMessageSenderTask.IsCompleted, "keepalive loop not completed"); } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 58b27ba..0c34bde 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -233,6 +233,74 @@ namespace MQTTnet.Core.Tests await c1.PublishAsync(message); } } + + [TestMethod] + public async Task MqttServer_ShutdownDisconnectsClientsGracefully() + { + var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); + var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); + + var disconnectCalled = 0; + + await s.StartAsync(new MqttServerOptions()); + + var c1 = new MqttFactory().CreateMqttClient(); + c1.Disconnected += (sender, args) => disconnectCalled++; + + await c1.ConnectAsync(clientOptions); + + await Task.Delay(100); + + await s.StopAsync(); + + await Task.Delay(100); + + Assert.AreEqual(1, disconnectCalled); + } + + [TestMethod] + public async Task MqttServer_HandleCleanDisconnect() + { + MqttNetGlobalLogger.LogMessagePublished += (_, e) => + { + System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}"); + }; + + var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); + var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); + + var clientConnectedCalled = 0; + var clientDisconnectedCalled = 0; + + s.ClientConnected += (_, __) => clientConnectedCalled++; + s.ClientDisconnected += (_, __) => clientDisconnectedCalled++; + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); + + await s.StartAsync(new MqttServerOptions()); + + var c1 = new MqttFactory().CreateMqttClient(); + + await c1.ConnectAsync(clientOptions); + + await Task.Delay(100); + + await c1.DisconnectAsync(); + + await Task.Delay(100); + + await s.StopAsync(); + + await Task.Delay(100); + + Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled); + } [TestMethod] public async Task MqttServer_RetainedMessagesFlow()