@@ -25,10 +25,11 @@ namespace MQTTnet.Client | |||
private IMqttClientOptions _options; | |||
private CancellationTokenSource _cancellationTokenSource; | |||
private Task _packetReceiverTask; | |||
private Task _keepAliveMessageSenderTask; | |||
internal Task _packetReceiverTask; | |||
internal Task _keepAliveMessageSenderTask; | |||
private IMqttChannelAdapter _adapter; | |||
private bool _cleanDisconnectInitiated; | |||
private TaskCompletionSource<bool> _disconnectReason; | |||
public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger) | |||
{ | |||
@@ -54,6 +55,7 @@ namespace MQTTnet.Client | |||
try | |||
{ | |||
_cancellationTokenSource = new CancellationTokenSource(); | |||
_disconnectReason = new TaskCompletionSource<bool>(); | |||
_options = options; | |||
_packetIdentifierProvider.Reset(); | |||
_packetDispatcher.Reset(); | |||
@@ -85,8 +87,10 @@ namespace MQTTnet.Client | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while connecting with server."); | |||
await DisconnectInternalAsync(null, exception).ConfigureAwait(false); | |||
if (_disconnectReason.TrySetException(exception)) | |||
{ | |||
await DisconnectInternalAsync(null, exception).ConfigureAwait(false); | |||
} | |||
throw; | |||
} | |||
} | |||
@@ -104,7 +108,10 @@ namespace MQTTnet.Client | |||
} | |||
finally | |||
{ | |||
await DisconnectInternalAsync(null, null).ConfigureAwait(false); | |||
if (_disconnectReason.TrySetCanceled()) | |||
{ | |||
await DisconnectInternalAsync(null, null).ConfigureAwait(false); | |||
} | |||
} | |||
} | |||
@@ -352,7 +359,10 @@ namespace MQTTnet.Client | |||
_logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets."); | |||
} | |||
await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); | |||
if (_disconnectReason.TrySetException(exception)) | |||
{ | |||
await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); | |||
} | |||
} | |||
finally | |||
{ | |||
@@ -395,7 +405,11 @@ namespace MQTTnet.Client | |||
} | |||
_packetDispatcher.Dispatch(exception); | |||
await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); | |||
if (_disconnectReason.TrySetException(exception)) | |||
{ | |||
await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); | |||
} | |||
} | |||
finally | |||
{ | |||
@@ -87,8 +87,14 @@ namespace MQTTnet.Implementations | |||
public void Dispose() | |||
{ | |||
TryDispose(_stream, () => _stream = null); | |||
TryDispose(_socket, () => _socket = null); | |||
Cleanup(ref _stream, (s) => s.Dispose()); | |||
Cleanup(ref _socket, (s) => { | |||
if (s.Connected) | |||
{ | |||
s.Shutdown(SocketShutdown.Both); | |||
} | |||
s.Dispose(); | |||
}); | |||
} | |||
private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) | |||
@@ -157,11 +163,16 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
private static void TryDispose(IDisposable disposable, Action afterDispose) | |||
private static void Cleanup<T>(ref T item, Action<T> handler) where T : class | |||
{ | |||
var temp = item; | |||
item = null; | |||
try | |||
{ | |||
disposable?.Dispose(); | |||
if (temp != null) | |||
{ | |||
handler(temp); | |||
} | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
@@ -169,10 +180,6 @@ namespace MQTTnet.Implementations | |||
catch (NullReferenceException) | |||
{ | |||
} | |||
finally | |||
{ | |||
afterDispose(); | |||
} | |||
} | |||
} | |||
} | |||
@@ -0,0 +1,3 @@ | |||
using System.Runtime.CompilerServices; | |||
[assembly:InternalsVisibleTo("MQTTnet.Core.Tests")] |
@@ -29,6 +29,8 @@ namespace MQTTnet.Server | |||
private MqttApplicationMessage _willMessage; | |||
private bool _wasCleanDisconnect; | |||
private IMqttChannelAdapter _adapter; | |||
private Task<bool> _workerTask; | |||
private IDisposable _cleanupHandle; | |||
public MqttClientSession( | |||
string clientId, | |||
@@ -65,7 +67,13 @@ namespace MQTTnet.Server | |||
status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; | |||
} | |||
public async Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) | |||
public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) | |||
{ | |||
_workerTask = RunInternalAsync(connectPacket, adapter); | |||
return _workerTask; | |||
} | |||
private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) | |||
{ | |||
if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); | |||
if (adapter == null) throw new ArgumentNullException(nameof(adapter)); | |||
@@ -77,6 +85,10 @@ namespace MQTTnet.Server | |||
adapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted; | |||
_cancellationTokenSource = new CancellationTokenSource(); | |||
//woraround for https://github.com/dotnet/corefx/issues/24430 | |||
_cleanupHandle = _cancellationTokenSource.Token.Register(() => Cleanup()); | |||
//endworkaround | |||
_wasCleanDisconnect = false; | |||
_willMessage = connectPacket.WillMessage; | |||
@@ -114,24 +126,49 @@ namespace MQTTnet.Server | |||
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); | |||
} | |||
Stop(MqttClientDisconnectType.NotClean); | |||
Stop(MqttClientDisconnectType.NotClean, true); | |||
} | |||
finally | |||
{ | |||
if (_adapter != null) | |||
await Cleanup().ConfigureAwait(false); | |||
_cleanupHandle?.Dispose(); | |||
_cleanupHandle = null; | |||
_cancellationTokenSource?.Dispose(); | |||
_cancellationTokenSource = null; | |||
} | |||
} | |||
private async Task Cleanup() | |||
{ | |||
try | |||
{ | |||
var adapter = _adapter; | |||
if (adapter == null) | |||
{ | |||
_adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; | |||
_adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; | |||
return; | |||
} | |||
_adapter = null; | |||
_cancellationTokenSource?.Dispose(); | |||
_cancellationTokenSource = null; | |||
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; | |||
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; | |||
await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); | |||
adapter.Dispose(); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, exception.Message); | |||
} | |||
} | |||
public void Stop(MqttClientDisconnectType type) | |||
{ | |||
Stop(type, false); | |||
} | |||
private void Stop(MqttClientDisconnectType type, bool isInsideSession) | |||
{ | |||
try | |||
{ | |||
@@ -151,6 +188,11 @@ namespace MQTTnet.Server | |||
} | |||
_willMessage = null; | |||
if (!isInsideSession) | |||
{ | |||
_workerTask?.GetAwaiter().GetResult(); | |||
} | |||
} | |||
finally | |||
{ | |||
@@ -298,18 +340,18 @@ namespace MQTTnet.Server | |||
if (packet is MqttDisconnectPacket) | |||
{ | |||
Stop(MqttClientDisconnectType.Clean); | |||
Stop(MqttClientDisconnectType.Clean, true); | |||
return; | |||
} | |||
if (packet is MqttConnectPacket) | |||
{ | |||
Stop(MqttClientDisconnectType.NotClean); | |||
Stop(MqttClientDisconnectType.NotClean, true); | |||
return; | |||
} | |||
_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); | |||
Stop(MqttClientDisconnectType.NotClean); | |||
Stop(MqttClientDisconnectType.NotClean, true); | |||
} | |||
private void EnqueueSubscribedRetainedMessages(ICollection<TopicFilter> topicFilters) | |||
@@ -328,7 +370,7 @@ namespace MQTTnet.Server | |||
if (subscribeResult.CloseConnection) | |||
{ | |||
Stop(MqttClientDisconnectType.NotClean); | |||
Stop(MqttClientDisconnectType.NotClean, true); | |||
return; | |||
} | |||
@@ -264,16 +264,6 @@ namespace MQTTnet.Server | |||
} | |||
finally | |||
{ | |||
try | |||
{ | |||
await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); | |||
clientAdapter.Dispose(); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, exception.Message); | |||
} | |||
if (!_options.EnablePersistentSessions) | |||
{ | |||
DeleteSession(clientId); | |||
@@ -106,14 +106,14 @@ namespace MQTTnet.Server | |||
_cancellationTokenSource.Cancel(false); | |||
_clientSessionsManager.Stop(); | |||
foreach (var adapter in _adapters) | |||
{ | |||
adapter.ClientAccepted -= OnClientAccepted; | |||
await adapter.StopAsync().ConfigureAwait(false); | |||
} | |||
_clientSessionsManager.Stop(); | |||
_logger.Info("Stopped."); | |||
Stopped?.Invoke(this, EventArgs.Empty); | |||
} | |||
@@ -1,10 +1,13 @@ | |||
using System; | |||
using System; | |||
using System.Net.Sockets; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Adapter; | |||
using MQTTnet.Client; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Packets; | |||
using MQTTnet.Implementations; | |||
using MQTTnet.Server; | |||
@@ -19,10 +22,10 @@ namespace MQTTnet.Core.Tests | |||
var factory = new MqttFactory(); | |||
var client = factory.CreateMqttClient(); | |||
var exceptionIsCorrect = false; | |||
Exception ex = null; | |||
client.Disconnected += (s, e) => | |||
{ | |||
exceptionIsCorrect = e.Exception is MqttCommunicationException c && c.InnerException is SocketException; | |||
ex = e.Exception; | |||
}; | |||
try | |||
@@ -32,8 +35,42 @@ namespace MQTTnet.Core.Tests | |||
catch | |||
{ | |||
} | |||
Assert.IsNotNull(ex); | |||
Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException)); | |||
Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException)); | |||
} | |||
[TestMethod] | |||
public async Task ClientCleanupOnAuthentificationFails() | |||
{ | |||
var channel = new TestMqttCommunicationAdapter(); | |||
var channel2 = new TestMqttCommunicationAdapter(); | |||
channel.Partner = channel2; | |||
channel2.Partner = channel; | |||
Task.Run(async () => { | |||
var connect = await channel2.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None); | |||
await channel2.SendPacketAsync(new MqttConnAckPacket() { ConnectReturnCode = Protocol.MqttConnectReturnCode.ConnectionRefusedNotAuthorized }, CancellationToken.None); | |||
}); | |||
Assert.IsTrue(exceptionIsCorrect); | |||
var fake = new TestMqttCommunicationAdapterFactory(channel); | |||
var client = new MqttClient(fake, new MqttNetLogger()); | |||
try | |||
{ | |||
await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("any-server").Build()); | |||
} | |||
catch (Exception ex) | |||
{ | |||
Assert.IsInstanceOfType(ex, typeof(MqttConnectingFailedException)); | |||
} | |||
Assert.IsTrue(client._packetReceiverTask == null || client._packetReceiverTask.IsCompleted, "receive loop not completed"); | |||
Assert.IsTrue(client._keepAliveMessageSenderTask == null || client._keepAliveMessageSenderTask.IsCompleted, "keepalive loop not completed"); | |||
} | |||
} | |||
} |
@@ -233,6 +233,74 @@ namespace MQTTnet.Core.Tests | |||
await c1.PublishAsync(message); | |||
} | |||
} | |||
[TestMethod] | |||
public async Task MqttServer_ShutdownDisconnectsClientsGracefully() | |||
{ | |||
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); | |||
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); | |||
var clientOptions = new MqttClientOptionsBuilder() | |||
.WithTcpServer("localhost") | |||
.Build(); | |||
var disconnectCalled = 0; | |||
await s.StartAsync(new MqttServerOptions()); | |||
var c1 = new MqttFactory().CreateMqttClient(); | |||
c1.Disconnected += (sender, args) => disconnectCalled++; | |||
await c1.ConnectAsync(clientOptions); | |||
await Task.Delay(100); | |||
await s.StopAsync(); | |||
await Task.Delay(100); | |||
Assert.AreEqual(1, disconnectCalled); | |||
} | |||
[TestMethod] | |||
public async Task MqttServer_HandleCleanDisconnect() | |||
{ | |||
MqttNetGlobalLogger.LogMessagePublished += (_, e) => | |||
{ | |||
System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}"); | |||
}; | |||
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); | |||
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); | |||
var clientConnectedCalled = 0; | |||
var clientDisconnectedCalled = 0; | |||
s.ClientConnected += (_, __) => clientConnectedCalled++; | |||
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++; | |||
var clientOptions = new MqttClientOptionsBuilder() | |||
.WithTcpServer("localhost") | |||
.Build(); | |||
await s.StartAsync(new MqttServerOptions()); | |||
var c1 = new MqttFactory().CreateMqttClient(); | |||
await c1.ConnectAsync(clientOptions); | |||
await Task.Delay(100); | |||
await c1.DisconnectAsync(); | |||
await Task.Delay(100); | |||
await s.StopAsync(); | |||
await Task.Delay(100); | |||
Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled); | |||
} | |||
[TestMethod] | |||
public async Task MqttServer_RetainedMessagesFlow() | |||