From e7c8d1c1c1623dfa9f85d6148d767ad2f4c8e94c Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Tue, 19 Sep 2017 23:06:45 +0200 Subject: [PATCH] Add cancellation token to adapter --- .../Adapter/IMqttCommunicationAdapter.cs | 5 +- .../MqttChannelCommunicationAdapter.cs | 53 ++++-- .../MqttCommunicationAdapterExtensions.cs | 5 +- MQTTnet.Core/Client/IMqttClient.cs | 8 +- MQTTnet.Core/Client/MqttClient.cs | 176 +++++++++--------- MQTTnet.Core/Client/MqttClientExtensions.cs | 25 ++- MQTTnet.Core/Server/MqttClientMessageQueue.cs | 2 +- MQTTnet.Core/Server/MqttClientSession.cs | 16 +- .../Server/MqttClientSessionsManager.cs | 8 +- Tests/MQTTnet.Core.Tests/ExtensionTests.cs | 44 ++++- .../MqttPacketSerializerTests.cs | 3 +- Tests/MQTTnet.Core.Tests/MqttServerTests.cs | 4 +- .../TestMqttCommunicationAdapter.cs | 11 +- 13 files changed, 215 insertions(+), 145 deletions(-) diff --git a/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs b/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs index 866253e..a80fc87 100644 --- a/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Client; using MQTTnet.Core.Packets; @@ -15,8 +16,8 @@ namespace MQTTnet.Core.Adapter Task DisconnectAsync(TimeSpan timeout); - Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets); + Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets); - Task ReceivePacketAsync(TimeSpan timeout); + Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); } } diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index 18d5776..89e9576 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; @@ -30,7 +31,11 @@ namespace MQTTnet.Core.Adapter { try { - await _channel.ConnectAsync(options).TimeoutAfter(timeout); + await _channel.ConnectAsync(options).TimeoutAfter(timeout).ConfigureAwait(false); + } + catch (TaskCanceledException) + { + throw; } catch (MqttCommunicationTimedOutException) { @@ -52,6 +57,10 @@ namespace MQTTnet.Core.Adapter { await _channel.DisconnectAsync().TimeoutAfter(timeout).ConfigureAwait(false); } + catch (TaskCanceledException) + { + throw; + } catch (MqttCommunicationTimedOutException) { throw; @@ -66,7 +75,7 @@ namespace MQTTnet.Core.Adapter } } - public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets) + public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) { try { @@ -77,20 +86,24 @@ namespace MQTTnet.Core.Adapter MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); var writeBuffer = PacketSerializer.Serialize(packet); - _sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); + _sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false), cancellationToken); } } - await _sendTask; // configure await false geneates stackoverflow + await _sendTask; // configure await false generates stackoverflow if (timeout > TimeSpan.Zero) { - await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); + await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); } else { - await _channel.SendStream.FlushAsync().ConfigureAwait(false); - } + await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + catch (TaskCanceledException) + { + throw; } catch (MqttCommunicationTimedOutException) { @@ -106,18 +119,23 @@ namespace MQTTnet.Core.Adapter } } - public async Task ReceivePacketAsync(TimeSpan timeout) + public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { try { ReceivedMqttPacket receivedMqttPacket; if (timeout > TimeSpan.Zero) { - receivedMqttPacket = await ReceiveAsync(_channel.RawReceiveStream).TimeoutAfter(timeout).ConfigureAwait(false); + receivedMqttPacket = await ReceiveAsync(_channel.RawReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); } else { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); + } + + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); } var packet = PacketSerializer.Deserialize(receivedMqttPacket); @@ -129,6 +147,10 @@ namespace MQTTnet.Core.Adapter MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); return packet; } + catch (TaskCanceledException) + { + throw; + } catch (MqttCommunicationTimedOutException) { throw; @@ -143,9 +165,9 @@ namespace MQTTnet.Core.Adapter } } - private static async Task ReceiveAsync(Stream stream) + private static async Task ReceiveAsync(Stream stream, CancellationToken cancellationToken) { - var header = MqttPacketReader.ReadHeaderFromSource(stream); + var header = MqttPacketReader.ReadHeaderFromSource(stream, cancellationToken); if (header.BodyLength == 0) { @@ -157,15 +179,10 @@ namespace MQTTnet.Core.Adapter var offset = 0; do { - var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset).ConfigureAwait(false); + var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset, cancellationToken).ConfigureAwait(false); offset += readBytesCount; } while (offset < header.BodyLength); - if (offset > header.BodyLength) - { - throw new MqttCommunicationException($"Read more body bytes than required ({offset}/{header.BodyLength})."); - } - return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length)); } } diff --git a/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs b/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs index d25d172..0fa9ab5 100644 --- a/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs +++ b/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Packets; @@ -6,9 +7,9 @@ namespace MQTTnet.Core.Adapter { public static class MqttCommunicationAdapterExtensions { - public static Task SendPacketsAsync(this IMqttCommunicationAdapter adapter, TimeSpan timeout, params MqttBasePacket[] packets) + public static Task SendPacketsAsync(this IMqttCommunicationAdapter adapter, TimeSpan timeout, CancellationToken cancellationToken, params MqttBasePacket[] packets) { - return adapter.SendPacketsAsync(timeout, packets); + return adapter.SendPacketsAsync(timeout, cancellationToken, packets); } } } \ No newline at end of file diff --git a/MQTTnet.Core/Client/IMqttClient.cs b/MQTTnet.Core/Client/IMqttClient.cs index 1b22edf..165f490 100644 --- a/MQTTnet.Core/Client/IMqttClient.cs +++ b/MQTTnet.Core/Client/IMqttClient.cs @@ -15,10 +15,10 @@ namespace MQTTnet.Core.Client Task ConnectAsync(MqttApplicationMessage willApplicationMessage = null); Task DisconnectAsync(); + + Task> SubscribeAsync(IEnumerable topicFilters); + Task UnsubscribeAsync(IEnumerable topicFilters); + Task PublishAsync(IEnumerable applicationMessages); - Task> SubscribeAsync(IList topicFilters); - Task> SubscribeAsync(params TopicFilter[] topicFilters); - Task Unsubscribe(IList topicFilters); - Task Unsubscribe(params string[] topicFilters); } } \ No newline at end of file diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index 4ca4373..f07a85d 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -15,12 +15,10 @@ namespace MQTTnet.Core.Client public class MqttClient : IMqttClient { private readonly HashSet _unacknowledgedPublishPackets = new HashSet(); - private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly MqttClientOptions _options; private readonly IMqttCommunicationAdapter _adapter; - private bool _disconnectedEventSuspended; private int _latestPacketIdentifier; private CancellationTokenSource _cancellationTokenSource; @@ -33,30 +31,27 @@ namespace MQTTnet.Core.Client } public event EventHandler Connected; - public event EventHandler Disconnected; - public event EventHandler ApplicationMessageReceived; - public bool IsConnected { get; private set; } + public bool IsConnected => _cancellationTokenSource != null && !_cancellationTokenSource.IsCancellationRequested; public async Task ConnectAsync(MqttApplicationMessage willApplicationMessage = null) { - MqttTrace.Verbose(nameof(MqttClient), "Trying to connect."); - - if (IsConnected) - { - throw new MqttProtocolViolationException("It is not allowed to connect with a server after the connection is established."); - } + ThrowIfConnected("It is not allowed to connect with a server after the connection is established."); try { - _disconnectedEventSuspended = false; - + MqttTrace.Verbose(nameof(MqttClient), "Trying to connect with server."); await _adapter.ConnectAsync(_options.DefaultCommunicationTimeout, _options).ConfigureAwait(false); - MqttTrace.Verbose(nameof(MqttClient), "Connection with server established."); + _cancellationTokenSource = new CancellationTokenSource(); + _latestPacketIdentifier = 0; + _packetDispatcher.Reset(); + + StartReceivePackets(_cancellationTokenSource.Token); + var connectPacket = new MqttConnectPacket { ClientId = _options.ClientId, @@ -67,28 +62,19 @@ namespace MQTTnet.Core.Client WillMessage = willApplicationMessage }; - _cancellationTokenSource = new CancellationTokenSource(); - _latestPacketIdentifier = 0; - _packetDispatcher.Reset(); - - StartReceivePackets(); - var response = await SendAndReceiveAsync(connectPacket).ConfigureAwait(false); if (response.ConnectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await DisconnectInternalAsync().ConfigureAwait(false); throw new MqttConnectingFailedException(response.ConnectReturnCode); } + MqttTrace.Verbose(nameof(MqttClient), "MQTT connection with server established."); + Connected?.Invoke(this, EventArgs.Empty); + if (_options.KeepAlivePeriod != TimeSpan.Zero) { - StartSendKeepAliveMessages(); + StartSendKeepAliveMessages(_cancellationTokenSource.Token); } - - MqttTrace.Verbose(nameof(MqttClient), "MQTT connection with server established."); - - IsConnected = true; - Connected?.Invoke(this, EventArgs.Empty); } catch (Exception) { @@ -114,56 +100,41 @@ namespace MQTTnet.Core.Client } } - public Task> SubscribeAsync(params TopicFilter[] topicFilters) + public async Task> SubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return SubscribeAsync(topicFilters.ToList()); - } - - public async Task> SubscribeAsync(IList topicFilters) - { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - if (!topicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); - ThrowIfNotConnected(); var subscribePacket = new MqttSubscribePacket { PacketIdentifier = GetNewPacketIdentifier(), - TopicFilters = topicFilters + TopicFilters = topicFilters.ToList() }; var response = await SendAndReceiveAsync(subscribePacket).ConfigureAwait(false); - if (response.SubscribeReturnCodes.Count != topicFilters.Count) + if (response.SubscribeReturnCodes.Count != subscribePacket.TopicFilters.Count) { throw new MqttProtocolViolationException("The return codes are not matching the topic filters [MQTT-3.9.3-1]."); } - return topicFilters.Select((t, i) => new MqttSubscribeResult(t, response.SubscribeReturnCodes[i])).ToList(); + return subscribePacket.TopicFilters.Select((t, i) => new MqttSubscribeResult(t, response.SubscribeReturnCodes[i])).ToList(); } - public Task Unsubscribe(params string[] topicFilters) + public async Task UnsubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return Unsubscribe(topicFilters.ToList()); - } - - public Task Unsubscribe(IList topicFilters) - { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - if (!topicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); ThrowIfNotConnected(); var unsubscribePacket = new MqttUnsubscribePacket { PacketIdentifier = GetNewPacketIdentifier(), - TopicFilters = topicFilters + TopicFilters = topicFilters.ToList() }; - return SendAndReceiveAsync(unsubscribePacket); + await SendAndReceiveAsync(unsubscribePacket); } public async Task PublishAsync(IEnumerable applicationMessages) @@ -178,9 +149,11 @@ namespace MQTTnet.Core.Client switch (qosGroup.Key) { case MqttQualityOfServiceLevel.AtMostOnce: - // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] - await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, qosPackets); - break; + { + // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] + await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, qosPackets); + break; + } case MqttQualityOfServiceLevel.AtLeastOnce: { foreach (var publishPacket in qosPackets) @@ -188,6 +161,7 @@ namespace MQTTnet.Core.Client publishPacket.PacketIdentifier = GetNewPacketIdentifier(); await SendAndReceiveAsync(publishPacket); } + break; } case MqttQualityOfServiceLevel.ExactlyOnce: @@ -195,95 +169,100 @@ namespace MQTTnet.Core.Client foreach (var publishPacket in qosPackets) { publishPacket.PacketIdentifier = GetNewPacketIdentifier(); - await PublishExactlyOncePacketAsync(publishPacket); + var pubRecPacket = await SendAndReceiveAsync(publishPacket).ConfigureAwait(false); + await SendAndReceiveAsync(pubRecPacket.CreateResponse()).ConfigureAwait(false); } + break; } default: - throw new InvalidOperationException(); + { + throw new InvalidOperationException(); + } } } } - private async Task PublishExactlyOncePacketAsync(MqttBasePacket publishPacket) + private void ThrowIfNotConnected() { - var pubRecPacket = await SendAndReceiveAsync(publishPacket).ConfigureAwait(false); - await SendAndReceiveAsync(pubRecPacket.CreateResponse()).ConfigureAwait(false); + if (!IsConnected) throw new MqttCommunicationException("The client is not connected."); } - private void ThrowIfNotConnected() + private void ThrowIfConnected(string message) { - if (!IsConnected) throw new MqttCommunicationException("The client is not connected."); + if (IsConnected) throw new MqttProtocolViolationException(message); } private async Task DisconnectInternalAsync() { + var cts = _cancellationTokenSource; + if (cts == null || cts.IsCancellationRequested) + { + return; + } + + cts.Cancel(false); + cts.Dispose(); + _cancellationTokenSource = null; + try { await _adapter.DisconnectAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + MqttTrace.Information(nameof(MqttClient), "Disconnected from adapter."); } catch (Exception exception) { - MqttTrace.Warning(nameof(MqttClient), exception, "Error while disconnecting."); + MqttTrace.Warning(nameof(MqttClient), exception, "Error while disconnecting from adapter."); } finally { - _cancellationTokenSource?.Cancel(false); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; - - IsConnected = false; - - if (!_disconnectedEventSuspended) - { - _disconnectedEventSuspended = true; - Disconnected?.Invoke(this, EventArgs.Empty); - } + Disconnected?.Invoke(this, EventArgs.Empty); } } - private async Task ProcessReceivedPacketAsync(MqttBasePacket mqttPacket) + private async Task ProcessReceivedPacketAsync(MqttBasePacket packet) { try { - if (mqttPacket is MqttPingReqPacket) + MqttTrace.Information(nameof(MqttClient), "Received <<< {0}", packet); + + if (packet is MqttPingReqPacket) { await SendAsync(new MqttPingRespPacket()); return; } - if (mqttPacket is MqttDisconnectPacket) + if (packet is MqttDisconnectPacket) { await DisconnectAsync(); return; } - if (mqttPacket is MqttPublishPacket publishPacket) + if (packet is MqttPublishPacket publishPacket) { await ProcessReceivedPublishPacket(publishPacket); return; } - if (mqttPacket is MqttPubRelPacket pubRelPacket) + if (packet is MqttPubRelPacket pubRelPacket) { await ProcessReceivedPubRelPacket(pubRelPacket); return; } - _packetDispatcher.Dispatch(mqttPacket); + _packetDispatcher.Dispatch(packet); } catch (Exception exception) { - MqttTrace.Error(nameof(MqttClient), exception, "Error while processing received packet."); + MqttTrace.Error(nameof(MqttClient), exception, "Unhandled exception while processing received packet."); } } private void FireApplicationMessageReceivedEvent(MqttPublishPacket publishPacket) { - var applicationMessage = publishPacket.ToApplicationMessage(); - try { + var applicationMessage = publishPacket.ToApplicationMessage(); ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(applicationMessage)); } catch (Exception exception) @@ -335,13 +314,13 @@ namespace MQTTnet.Core.Client private Task SendAsync(MqttBasePacket packet) { - return _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, packet); + return _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, packet); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket { var wait = _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.DefaultCommunicationTimeout); - await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, requestPacket).ConfigureAwait(false); + await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false); return (TResponsePacket)await wait.ConfigureAwait(false); } @@ -359,17 +338,25 @@ namespace MQTTnet.Core.Client while (!cancellationToken.IsCancellationRequested) { await Task.Delay(_options.KeepAlivePeriod, cancellationToken).ConfigureAwait(false); + if (cancellationToken.IsCancellationRequested) + { + return; + } + await SendAndReceiveAsync(new MqttPingReqPacket()).ConfigureAwait(false); } } + catch (TaskCanceledException) + { + } catch (MqttCommunicationException exception) { - MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication error while receiving packets."); + MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication exception while sending/receiving keep alive packets."); await DisconnectInternalAsync().ConfigureAwait(false); } catch (Exception exception) { - MqttTrace.Warning(nameof(MqttClient), exception, "Error while sending/receiving keep alive packets."); + MqttTrace.Warning(nameof(MqttClient), exception, "Unhandled exception while sending/receiving keep alive packets."); await DisconnectInternalAsync().ConfigureAwait(false); } finally @@ -381,16 +368,23 @@ namespace MQTTnet.Core.Client private async Task ReceivePackets(CancellationToken cancellationToken) { MqttTrace.Information(nameof(MqttClient), "Start receiving packets."); + try { while (!cancellationToken.IsCancellationRequested) { - var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero).ConfigureAwait(false); - MqttTrace.Information(nameof(MqttClient), "Received <<< {0}", packet); + var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); + if (cancellationToken.IsCancellationRequested) + { + return; + } StartProcessReceivedPacket(packet, cancellationToken); } } + catch (TaskCanceledException) + { + } catch (MqttCommunicationException exception) { MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication exception while receiving packets."); @@ -410,21 +404,21 @@ namespace MQTTnet.Core.Client private void StartProcessReceivedPacket(MqttBasePacket packet, CancellationToken cancellationToken) { #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(() => ProcessReceivedPacketAsync(packet), cancellationToken); + Task.Run(async () => await ProcessReceivedPacketAsync(packet), cancellationToken).ConfigureAwait(false); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } - private void StartReceivePackets() + private void StartReceivePackets(CancellationToken cancellationToken) { #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(() => ReceivePackets(_cancellationTokenSource.Token), _cancellationTokenSource.Token); + Task.Run(async () => await ReceivePackets(cancellationToken), cancellationToken).ConfigureAwait(false); ; #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } - private void StartSendKeepAliveMessages() + private void StartSendKeepAliveMessages(CancellationToken cancellationToken) { #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - Task.Run(() => SendKeepAliveMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); + Task.Run(async () => await SendKeepAliveMessagesAsync(cancellationToken), cancellationToken).ConfigureAwait(false); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } } diff --git a/MQTTnet.Core/Client/MqttClientExtensions.cs b/MQTTnet.Core/Client/MqttClientExtensions.cs index 5e9875a..f664dbd 100644 --- a/MQTTnet.Core/Client/MqttClientExtensions.cs +++ b/MQTTnet.Core/Client/MqttClientExtensions.cs @@ -1,4 +1,8 @@ -using System.Threading.Tasks; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using MQTTnet.Core.Packets; namespace MQTTnet.Core.Client { @@ -6,7 +10,26 @@ namespace MQTTnet.Core.Client { public static Task PublishAsync(this IMqttClient client, params MqttApplicationMessage[] applicationMessages) { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); + return client.PublishAsync(applicationMessages); } + + public static Task> SubscribeAsync(this IMqttClient client, params TopicFilter[] topicFilters) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return client.SubscribeAsync(topicFilters.ToList()); + } + + public static Task UnsubscribeAsync(this IMqttClient client, params string[] topicFilters) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return client.UnsubscribeAsync(topicFilters.ToList()); + } } } \ No newline at end of file diff --git a/MQTTnet.Core/Server/MqttClientMessageQueue.cs b/MQTTnet.Core/Server/MqttClientMessageQueue.cs index 0c0b6da..acff44f 100644 --- a/MQTTnet.Core/Server/MqttClientMessageQueue.cs +++ b/MQTTnet.Core/Server/MqttClientMessageQueue.cs @@ -63,7 +63,7 @@ namespace MQTTnet.Core.Server var packets = consumable.Take(_pendingPublishPackets.Count).ToList(); try { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, packets).ConfigureAwait(false); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, packets).ConfigureAwait(false); } catch (MqttCommunicationException exception) { diff --git a/MQTTnet.Core/Server/MqttClientSession.cs b/MQTTnet.Core/Server/MqttClientSession.cs index 1411dd3..4836396 100644 --- a/MQTTnet.Core/Server/MqttClientSession.cs +++ b/MQTTnet.Core/Server/MqttClientSession.cs @@ -54,7 +54,7 @@ namespace MQTTnet.Core.Server _messageQueue.Start(adapter); while (!_cancellationTokenSource.IsCancellationRequested) { - var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero).ConfigureAwait(false); + var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false); await HandleIncomingPacketAsync(packet).ConfigureAwait(false); } } @@ -103,12 +103,12 @@ namespace MQTTnet.Core.Server { if (packet is MqttSubscribePacket subscribePacket) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _subscriptionsManager.Subscribe(subscribePacket)); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Subscribe(subscribePacket)); } if (packet is MqttUnsubscribePacket unsubscribePacket) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _subscriptionsManager.Unsubscribe(unsubscribePacket)); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Unsubscribe(unsubscribePacket)); } if (packet is MqttPublishPacket publishPacket) @@ -123,7 +123,7 @@ namespace MQTTnet.Core.Server if (packet is MqttPubRecPacket pubRecPacket) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, pubRecPacket.CreateResponse()); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, pubRecPacket.CreateResponse()); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -134,7 +134,7 @@ namespace MQTTnet.Core.Server if (packet is MqttPingReqPacket) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket()); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPingRespPacket()); } if (packet is MqttDisconnectPacket || packet is MqttConnectPacket) @@ -160,7 +160,7 @@ namespace MQTTnet.Core.Server if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) { _publishPacketReceivedCallback(this, publishPacket); - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); } if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) @@ -173,7 +173,7 @@ namespace MQTTnet.Core.Server _publishPacketReceivedCallback(this, publishPacket); - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); } throw new MqttCommunicationException("Received a not supported QoS level."); @@ -186,7 +186,7 @@ namespace MQTTnet.Core.Server _unacknowledgedPublishPackets.Remove(pubRelPacket.PacketIdentifier); } - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }); } } } diff --git a/MQTTnet.Core/Server/MqttClientSessionsManager.cs b/MQTTnet.Core/Server/MqttClientSessionsManager.cs index f81b8bb..31688c9 100644 --- a/MQTTnet.Core/Server/MqttClientSessionsManager.cs +++ b/MQTTnet.Core/Server/MqttClientSessionsManager.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Adapter; using MQTTnet.Core.Diagnostics; @@ -28,8 +29,7 @@ namespace MQTTnet.Core.Server { try { - var connectPacket = await eventArgs.ClientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false) as MqttConnectPacket; - if (connectPacket == null) + if (!(await eventArgs.ClientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false) is MqttConnectPacket connectPacket)) { throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1]."); } @@ -40,7 +40,7 @@ namespace MQTTnet.Core.Server var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttConnAckPacket + await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, CancellationToken.None, new MqttConnAckPacket { ConnectReturnCode = connectReturnCode }).ConfigureAwait(false); @@ -50,7 +50,7 @@ namespace MQTTnet.Core.Server var clientSession = GetOrCreateClientSession(connectPacket); - await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttConnAckPacket + await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, CancellationToken.None, new MqttConnAckPacket { ConnectReturnCode = connectReturnCode, IsSessionPresent = clientSession.IsExistingSession diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index 30a84f4..b827326 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -11,23 +11,61 @@ namespace MQTTnet.Core.Tests { [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] - public async Task TestTimeoutAfter() + public async Task TimeoutAfter() { await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] - public async Task TestTimeoutAfterWithResult() + public async Task TimeoutAfterWithResult() { await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); } [TestMethod] - public async Task TestTimeoutAfterCompleteInTime() + public async Task TimeoutAfterCompleteInTime() { var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500)); Assert.AreEqual(5, result); } + + [TestMethod] + public async Task TimeoutAfterWithInnerException() + { + try + { + await Task.Run(() => + { + var iis = new int[0]; + iis[1] = 0; + }).TimeoutAfter(TimeSpan.FromSeconds(1)); + + Assert.Fail(); + } + catch (MqttCommunicationException e) + { + Assert.IsTrue(e.InnerException is IndexOutOfRangeException); + } + } + + [TestMethod] + public async Task TimeoutAfterWithInnerExceptionWithResult() + { + try + { + var r = await Task.Run(() => + { + var iis = new int[0]; + return iis[1]; + }).TimeoutAfter(TimeSpan.FromSeconds(1)); + + Assert.Fail(); + } + catch (MqttCommunicationException e) + { + Assert.IsTrue(e.InnerException is IndexOutOfRangeException); + } + } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index d97949f..eb7cb83 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Core.Adapter; @@ -442,7 +443,7 @@ namespace MQTTnet.Core.Tests using (var headerStream = new MemoryStream(buffer1)) { - var header = MqttPacketReader.ReadHeaderFromSource(headerStream); + var header = MqttPacketReader.ReadHeaderFromSource(headerStream, CancellationToken.None); using (var bodyStream = new MemoryStream(buffer1, (int)headerStream.Position, header.BodyLength)) { diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 9ee2290..d3dc7f1 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -91,7 +91,7 @@ namespace MQTTnet.Core.Tests await Task.Delay(500); Assert.AreEqual(1, receivedMessagesCount); - await c1.Unsubscribe("a"); + await c1.UnsubscribeAsync("a"); await c2.PublishAsync(message); await Task.Delay(500); @@ -158,7 +158,7 @@ namespace MQTTnet.Core.Tests await c2.PublishAsync(new MqttApplicationMessage(topic, new byte[0], qualityOfServiceLevel, false)); await Task.Delay(500); - await c1.Unsubscribe(topicFilter); + await c1.UnsubscribeAsync(topicFilter); await Task.Delay(500); diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index 281d7fd..91d0719 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -28,7 +28,7 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets) + public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) { ThrowIfPartnerIsNull(); @@ -40,16 +40,11 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task ReceivePacketAsync(TimeSpan timeout) + public Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfPartnerIsNull(); - return Task.Run(() => _incomingPackets.Take()); - } - - public IEnumerable ReceivePackets(CancellationToken cancellationToken) - { - return _incomingPackets.GetConsumingEnumerable(); + return Task.Run(() => _incomingPackets.Take(), cancellationToken); } private void SendPacketInternal(MqttBasePacket packet)