diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index bb81a8d..a905b00 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -14,6 +14,7 @@ * [Server] Providing the used protocol version of connected clients * [Client] Added support for protocol version 3.1.0 * [Core] Several minor performance improvements +* [Core] Fixed an issue with connection management (Thanks to wuzhenda; Zuendelmeister) Copyright Christian Kratky 2016-2017 MQTT MQTTClient MQTTServer MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Queue Hardware Arduino diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs index 3a43a3e..c9c080d 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs @@ -12,12 +12,11 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { - private readonly Socket _socket; + private Socket _socket; private SslStream _sslStream; public MqttTcpChannel() { - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); } public MqttTcpChannel(Socket socket, SslStream sslStream) @@ -31,6 +30,11 @@ namespace MQTTnet.Implementations if (options == null) throw new ArgumentNullException(nameof(options)); try { + if (_socket == null) + { + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + } + await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null); if (options.TlsOptions.UseTls) @@ -49,8 +53,7 @@ namespace MQTTnet.Implementations { try { - _sslStream.Dispose(); - _socket.Dispose(); + Dispose(); return Task.FromResult(0); } catch (SocketException exception) @@ -108,6 +111,9 @@ namespace MQTTnet.Implementations { _socket?.Dispose(); _sslStream?.Dispose(); + + _socket = null; + _sslStream = null; } private static X509CertificateCollection LoadCertificates(MqttClientOptions options) diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index e78cd98..a4247b0 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -12,12 +12,11 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { - private readonly Socket _socket; + private Socket _socket; private SslStream _sslStream; public MqttTcpChannel() { - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); } public MqttTcpChannel(Socket socket, SslStream sslStream) @@ -31,6 +30,11 @@ namespace MQTTnet.Implementations if (options == null) throw new ArgumentNullException(nameof(options)); try { + if (_socket == null) + { + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + } + await _socket.ConnectAsync(options.Server, options.GetPort()); if (options.TlsOptions.UseTls) @@ -49,8 +53,7 @@ namespace MQTTnet.Implementations { try { - _sslStream.Dispose(); - _socket.Dispose(); + Dispose(); return Task.FromResult(0); } catch (SocketException exception) @@ -101,6 +104,9 @@ namespace MQTTnet.Implementations { _socket?.Dispose(); _sslStream?.Dispose(); + + _socket = null; + _sslStream = null; } private static X509CertificateCollection LoadCertificates(MqttClientOptions options) diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs index 99681b7..482ce32 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs @@ -15,11 +15,10 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { - private readonly StreamSocket _socket; + private StreamSocket _socket; public MqttTcpChannel() { - _socket = new StreamSocket(); } public MqttTcpChannel(StreamSocket socket) @@ -32,6 +31,11 @@ namespace MQTTnet.Implementations if (options == null) throw new ArgumentNullException(nameof(options)); try { + if (_socket == null) + { + _socket = new StreamSocket(); + } + if (!options.TlsOptions.UseTls) { await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); @@ -59,7 +63,7 @@ namespace MQTTnet.Implementations { try { - _socket.Dispose(); + Dispose(); return Task.FromResult(0); } catch (SocketException exception) @@ -100,6 +104,8 @@ namespace MQTTnet.Implementations public void Dispose() { _socket?.Dispose(); + + _socket = null; } private static Certificate LoadCertificate(MqttClientOptions options) diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index d504d89..04b5ca7 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -23,11 +23,7 @@ namespace MQTTnet.Core.Adapter public async Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) { - var task = _channel.ConnectAsync(options); - if (await Task.WhenAny(Task.Delay(timeout), task) != task) - { - throw new MqttCommunicationTimedOutException(); - } + await ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout); } public async Task DisconnectAsync() @@ -39,21 +35,7 @@ namespace MQTTnet.Core.Adapter { MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), $"TX >>> {packet} [Timeout={timeout}]"); - bool hasTimeout; - try - { - var task = PacketSerializer.SerializeAsync(packet, _channel); - hasTimeout = await Task.WhenAny(Task.Delay(timeout), task) != task; - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } - - if (hasTimeout) - { - throw new MqttCommunicationTimedOutException(); - } + await ExecuteWithTimeoutAsync(PacketSerializer.SerializeAsync(packet, _channel), timeout); } public async Task ReceivePacketAsync(TimeSpan timeout) @@ -61,16 +43,7 @@ namespace MQTTnet.Core.Adapter MqttBasePacket packet; if (timeout > TimeSpan.Zero) { - var workerTask = PacketSerializer.DeserializeAsync(_channel); - var timeoutTask = Task.Delay(timeout); - var hasTimeout = Task.WhenAny(timeoutTask, workerTask) == timeoutTask; - - if (hasTimeout) - { - throw new MqttCommunicationTimedOutException(); - } - - packet = workerTask.Result; + packet = await ExecuteWithTimeoutAsync(PacketSerializer.DeserializeAsync(_channel), timeout); } else { @@ -85,5 +58,35 @@ namespace MQTTnet.Core.Adapter MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), $"RX <<< {packet}"); return packet; } + + private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) + { + var timeoutTask = Task.Delay(timeout); + if (await Task.WhenAny(timeoutTask, task) == timeoutTask) + { + throw new MqttCommunicationTimedOutException(); + } + + if (task.IsFaulted) + { + throw new MqttCommunicationException(task.Exception); + } + + return task.Result; + } + + private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) + { + var timeoutTask = Task.Delay(timeout); + if (await Task.WhenAny(timeoutTask, task) == timeoutTask) + { + throw new MqttCommunicationTimedOutException(); + } + + if (task.IsFaulted) + { + throw new MqttCommunicationException(task.Exception); + } + } } } \ No newline at end of file diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index 70e8c76..3481cbe 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -20,6 +20,7 @@ namespace MQTTnet.Core.Client private readonly MqttClientOptions _options; private readonly IMqttCommunicationAdapter _adapter; + private bool _disconnectedEventSuspended; private int _latestPacketIdentifier; private CancellationTokenSource _cancellationTokenSource; @@ -48,49 +49,64 @@ namespace MQTTnet.Core.Client throw new MqttProtocolViolationException("It is not allowed to connect with a server after the connection is established."); } - var connectPacket = new MqttConnectPacket + try { - ClientId = _options.ClientId, - Username = _options.UserName, - Password = _options.Password, - CleanSession = _options.CleanSession, - KeepAlivePeriod = (ushort)_options.KeepAlivePeriod.TotalSeconds, - WillMessage = willApplicationMessage - }; + _disconnectedEventSuspended = false; - await _adapter.ConnectAsync(_options, _options.DefaultCommunicationTimeout); - MqttTrace.Verbose(nameof(MqttClient), "Connection with server established."); + await _adapter.ConnectAsync(_options, _options.DefaultCommunicationTimeout); - _cancellationTokenSource = new CancellationTokenSource(); - _latestPacketIdentifier = 0; - _packetDispatcher.Reset(); - IsConnected = true; + MqttTrace.Verbose(nameof(MqttClient), "Connection with server established."); -#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(() => ReceivePackets(_cancellationTokenSource.Token), _cancellationTokenSource.Token); -#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + var connectPacket = new MqttConnectPacket + { + ClientId = _options.ClientId, + Username = _options.UserName, + Password = _options.Password, + CleanSession = _options.CleanSession, + KeepAlivePeriod = (ushort)_options.KeepAlivePeriod.TotalSeconds, + WillMessage = willApplicationMessage + }; + + _cancellationTokenSource = new CancellationTokenSource(); + _latestPacketIdentifier = 0; + _packetDispatcher.Reset(); + + StartReceivePackets(); + + var response = await SendAndReceiveAsync(connectPacket); + if (response.ConnectReturnCode != MqttConnectReturnCode.ConnectionAccepted) + { + await DisconnectInternalAsync(); + throw new MqttConnectingFailedException(response.ConnectReturnCode); + } - var response = await SendAndReceiveAsync(connectPacket); - if (response.ConnectReturnCode != MqttConnectReturnCode.ConnectionAccepted) - { - await DisconnectAsync(); - throw new MqttConnectingFailedException(response.ConnectReturnCode); - } + if (_options.KeepAlivePeriod != TimeSpan.Zero) + { + StartSendKeepAliveMessages(); + } + + MqttTrace.Verbose(nameof(MqttClient), "MQTT connection with server established."); - if (_options.KeepAlivePeriod != TimeSpan.Zero) + IsConnected = true; + Connected?.Invoke(this, EventArgs.Empty); + } + catch (Exception) { -#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(() => SendKeepAliveMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); -#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + await DisconnectInternalAsync(); + throw; } - - Connected?.Invoke(this, EventArgs.Empty); } public async Task DisconnectAsync() { - await SendAsync(new MqttDisconnectPacket()); - await DisconnectInternalAsync(); + try + { + await SendAsync(new MqttDisconnectPacket()); + } + finally + { + await DisconnectInternalAsync(); + } } public Task> SubscribeAsync(params TopicFilter[] topicFilters) @@ -181,8 +197,9 @@ namespace MQTTnet.Core.Client { await _adapter.DisconnectAsync(); } - catch + catch (Exception exception) { + MqttTrace.Warning(nameof(MqttClient), exception, "Error while disconnecting."); } finally { @@ -191,7 +208,12 @@ namespace MQTTnet.Core.Client _cancellationTokenSource = null; IsConnected = false; - Disconnected?.Invoke(this, EventArgs.Empty); + + if (!_disconnectedEventSuspended) + { + _disconnectedEventSuspended = true; + Disconnected?.Invoke(this, EventArgs.Empty); + } } } @@ -239,7 +261,7 @@ namespace MQTTnet.Core.Client } catch (Exception exception) { - MqttTrace.Error(nameof(MqttClient), exception, "Unhandled exception while handling application message."); + MqttTrace.Error(nameof(MqttClient), exception, "Unhandled exception while handling application message."); } } @@ -278,7 +300,7 @@ namespace MQTTnet.Core.Client { _unacknowledgedPublishPackets.Remove(pubRelPacket.PacketIdentifier); } - + await SendAsync(pubRelPacket.CreateResponse()); } @@ -300,15 +322,12 @@ namespace MQTTnet.Core.Client var pi1 = requestPacket as IMqttPacketWithIdentifier; var pi2 = p as IMqttPacketWithIdentifier; - if (pi1 != null && pi2 != null) + if (pi1 == null || pi2 == null) { - if (pi1.PacketIdentifier != pi2.PacketIdentifier) - { - return false; - } + return true; } - return true; + return pi1.PacketIdentifier == pi2.PacketIdentifier; } await _adapter.SendPacketAsync(requestPacket, _options.DefaultCommunicationTimeout); @@ -335,15 +354,16 @@ namespace MQTTnet.Core.Client catch (MqttCommunicationException exception) { MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication error while receiving packets."); + await DisconnectInternalAsync(); } catch (Exception exception) { MqttTrace.Warning(nameof(MqttClient), exception, "Error while sending/receiving keep alive packets."); + await DisconnectInternalAsync(); } finally { MqttTrace.Information(nameof(MqttClient), "Stopped sending keep alive packets."); - await DisconnectInternalAsync(); } } @@ -354,27 +374,47 @@ namespace MQTTnet.Core.Client { while (!cancellationToken.IsCancellationRequested) { - var mqttPacket = await _adapter.ReceivePacketAsync(TimeSpan.Zero); - MqttTrace.Information(nameof(MqttClient), $"Received <<< {mqttPacket}"); + var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero); + MqttTrace.Information(nameof(MqttClient), $"Received <<< {packet}"); -#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(() => ProcessReceivedPacketAsync(mqttPacket), cancellationToken); -#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + StartProcessReceivedPacket(packet, cancellationToken); } } catch (MqttCommunicationException exception) { - MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication error while receiving packets."); + MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication exception while receiving packets."); + await DisconnectInternalAsync(); } catch (Exception exception) { - MqttTrace.Error(nameof(MqttClient), exception, "Error while receiving packets."); + MqttTrace.Error(nameof(MqttClient), exception, "Unhandled exception while receiving packets."); + await DisconnectInternalAsync(); } finally { MqttTrace.Information(nameof(MqttClient), "Stopped receiving packets."); - await DisconnectInternalAsync(); } } + + private void StartProcessReceivedPacket(MqttBasePacket packet, CancellationToken cancellationToken) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + Task.Run(() => ProcessReceivedPacketAsync(packet), cancellationToken); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + private void StartReceivePackets() + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + Task.Run(() => ReceivePackets(_cancellationTokenSource.Token), _cancellationTokenSource.Token); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + private void StartSendKeepAliveMessages() + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + Task.Run(() => SendKeepAliveMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } } } \ No newline at end of file diff --git a/MQTTnet.Core/Exceptions/MqttCommunicationException.cs b/MQTTnet.Core/Exceptions/MqttCommunicationException.cs index 8b471a0..2fc578e 100644 --- a/MQTTnet.Core/Exceptions/MqttCommunicationException.cs +++ b/MQTTnet.Core/Exceptions/MqttCommunicationException.cs @@ -4,7 +4,7 @@ namespace MQTTnet.Core.Exceptions { public class MqttCommunicationException : Exception { - public MqttCommunicationException() + protected MqttCommunicationException() { } @@ -17,5 +17,10 @@ namespace MQTTnet.Core.Exceptions : base(message) { } + + public MqttCommunicationException(string message, Exception innerException) + : base(message, innerException) + { + } } }