From 109e4b2cf13d0698e534da84e59539731e8b4c04 Mon Sep 17 00:00:00 2001 From: Christian Date: Wed, 18 Apr 2018 21:14:31 +0200 Subject: [PATCH] Refactor the task timeout code to avoid still running tasks if the timeout is reached. --- .../MQTTnet.Extensions.Rpc/MqttRpcClient.cs | 33 +++++++-- .../Adapter/MqttChannelAdapter.cs | 36 +++------- .../Client/MqttPacketDispatcher.cs | 13 ++-- .../Internal/TaskExtensions.cs | 70 +++++++------------ Tests/MQTTnet.Core.Tests/ExtensionTests.cs | 32 +++++---- 5 files changed, 86 insertions(+), 98 deletions(-) diff --git a/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs b/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs index 78836ea..b8cfd47 100644 --- a/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs +++ b/Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs @@ -1,9 +1,9 @@ using System; using System.Collections.Concurrent; using System.Text; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Client; -using MQTTnet.Internal; using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc @@ -12,7 +12,7 @@ namespace MQTTnet.Extensions.Rpc { private readonly ConcurrentDictionary> _waitingCalls = new ConcurrentDictionary>(); private readonly IMqttClient _mqttClient; - + public MqttRpcClient(IMqttClient mqttClient) { _mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient)); @@ -22,10 +22,20 @@ namespace MQTTnet.Extensions.Rpc public Task ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) { - return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel); + return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None); + } + + public Task ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) + { + return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken); + } + + public Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + { + return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None); } - public async Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) + public async Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) { if (methodName == null) throw new ArgumentNullException(nameof(methodName)); @@ -54,7 +64,20 @@ namespace MQTTnet.Extensions.Rpc await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false); await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false); - return await tcs.Task.TimeoutAfter(timeout); + using (var timeoutCts = new CancellationTokenSource(timeout)) + { + timeoutCts.Token.Register(() => + { + if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled) + { + tcs.TrySetCanceled(); + } + }); + + var result = await tcs.Task.ConfigureAwait(false); + timeoutCts.Cancel(false); + return result; + } } finally { diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index 1d48f23..79192c4 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -7,7 +7,6 @@ using System.Threading.Tasks; using MQTTnet.Channel; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Serializer; @@ -18,11 +17,11 @@ namespace MQTTnet.Adapter private const uint ErrorOperationAborted = 0x800703E3; private const int ReadBufferSize = 4096; // TODO: Move buffer size to config - private bool _isDisposed; - private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly IMqttNetLogger _logger; private readonly IMqttChannel _channel; + private bool _isDisposed; + public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger) { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); @@ -37,7 +36,8 @@ namespace MQTTnet.Adapter ThrowIfDisposed(); _logger.Verbose("Connecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout)); + return ExecuteAndWrapExceptionAsync(() => + Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); } public Task DisconnectAsync(TimeSpan timeout) @@ -45,7 +45,8 @@ namespace MQTTnet.Adapter ThrowIfDisposed(); _logger.Verbose("Disconnecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); + return ExecuteAndWrapExceptionAsync(() => + Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, CancellationToken.None)); } public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) @@ -71,11 +72,11 @@ namespace MQTTnet.Adapter var packetData = PacketSerializer.Serialize(packet); - return _channel.WriteAsync( + return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( packetData.Array, packetData.Offset, packetData.Count, - cancellationToken); + ct), timeout, cancellationToken); }); } @@ -91,25 +92,7 @@ namespace MQTTnet.Adapter { if (timeout > TimeSpan.Zero) { - var timeoutCts = new CancellationTokenSource(timeout); - var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); - - try - { - receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false); - } - catch (OperationCanceledException exception) - { - var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; - if (timedOut) - { - throw new MqttCommunicationTimedOutException(exception); - } - else - { - throw; - } - } + receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); } else { @@ -232,7 +215,6 @@ namespace MQTTnet.Adapter public void Dispose() { _isDisposed = true; - _semaphore?.Dispose(); _channel?.Dispose(); } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs index 172e53b..5035a1e 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs @@ -1,16 +1,15 @@ using System; using System.Collections.Concurrent; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; -using MQTTnet.Internal; using MQTTnet.Packets; namespace MQTTnet.Client { public class MqttPacketDispatcher { - private readonly ConcurrentDictionary, TaskCompletionSource> _awaiters = new ConcurrentDictionary, TaskCompletionSource>(); private readonly IMqttNetLogger _logger; @@ -23,8 +22,8 @@ namespace MQTTnet.Client { var packetAwaiter = AddPacketAwaiter(responseType, identifier); try - { - return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false); + { + return await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, timeout, CancellationToken.None).ConfigureAwait(false); } catch (MqttCommunicationTimedOutException) { @@ -49,7 +48,7 @@ namespace MQTTnet.Client var type = packet.GetType(); var key = new Tuple(identifier, type); - + if (_awaiters.TryRemove(key, out var tcs)) { @@ -74,8 +73,8 @@ namespace MQTTnet.Client identifier = 0; } - var dictionaryKey = new Tuple(identifier, responseType); - if (!_awaiters.TryAdd(dictionaryKey,tcs)) + var dictionaryKey = new Tuple(identifier, responseType); + if (!_awaiters.TryAdd(dictionaryKey, tcs)) { throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); } diff --git a/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs b/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs index 876c549..288ac0b 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs @@ -7,72 +7,52 @@ namespace MQTTnet.Internal { public static class TaskExtensions { - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter(Func action, TimeSpan timeout, CancellationToken cancellationToken) { - if (task == null) throw new ArgumentNullException(nameof(task)); + if (action == null) throw new ArgumentNullException(nameof(action)); - using (var timeoutCts = new CancellationTokenSource()) + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) { try { - var timeoutTask = Task.Delay(timeout, timeoutCts.Token); - var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - - if (finishedTask == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsCanceled) + await action(linkedCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException exception) + { + var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; + if (timeoutReached) { - throw new TaskCanceledException(); + throw new MqttCommunicationTimedOutException(exception); } - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception?.GetBaseException()); - } - } - finally - { - timeoutCts.Cancel(); + throw; } } } - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter(Func> action, TimeSpan timeout, CancellationToken cancellationToken) { - if (task == null) throw new ArgumentNullException(nameof(task)); + if (action == null) throw new ArgumentNullException(nameof(action)); - using (var timeoutCts = new CancellationTokenSource()) + using (var timeoutCts = new CancellationTokenSource(timeout)) + using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) { try { - var timeoutTask = Task.Delay(timeout, timeoutCts.Token); - var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - - if (finishedTask == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsCanceled) - { - throw new TaskCanceledException(); - } - - if (task.IsFaulted) + return await action(linkedCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException exception) + { + var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; + if (timeoutReached) { - throw new MqttCommunicationException(task.Exception.GetBaseException()); + throw new MqttCommunicationTimedOutException(exception); } - return task.Result; - } - finally - { - timeoutCts.Cancel(); + throw; } } - } + } } } diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index 053b0ca..daeea62 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -1,9 +1,9 @@ using System; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Exceptions; -using MQTTnet.Internal; namespace MQTTnet.Core.Tests { @@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests [TestMethod] public async Task TimeoutAfter() { - await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [ExpectedException(typeof(MqttCommunicationTimedOutException))] [TestMethod] public async Task TimeoutAfterWithResult() { - await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); } [TestMethod] public async Task TimeoutAfterCompleteInTime() { - var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500)); + var result = await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); Assert.AreEqual(5, result); } @@ -36,17 +36,17 @@ namespace MQTTnet.Core.Tests { try { - await Task.Run(() => + await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => { var iis = new int[0]; iis[1] = 0; - }).TimeoutAfter(TimeSpan.FromSeconds(1)); + }, ct), TimeSpan.FromSeconds(1), CancellationToken.None); Assert.Fail(); } - catch (MqttCommunicationException e) + catch (Exception e) { - Assert.IsTrue(e.InnerException is IndexOutOfRangeException); + Assert.IsTrue(e is IndexOutOfRangeException); } } @@ -55,17 +55,18 @@ namespace MQTTnet.Core.Tests { try { - await Task.Run(() => + await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => { var iis = new int[0]; - return iis[1]; - }).TimeoutAfter(TimeSpan.FromSeconds(1)); + iis[1] = 0; + return iis[0]; + }, ct), TimeSpan.FromSeconds(1), CancellationToken.None); Assert.Fail(); } - catch (MqttCommunicationException e) + catch (Exception e) { - Assert.IsTrue(e.InnerException is IndexOutOfRangeException); + Assert.IsTrue(e is IndexOutOfRangeException); } } @@ -73,7 +74,10 @@ namespace MQTTnet.Core.Tests public async Task TimeoutAfterMemoryUsage() { var tasks = Enumerable.Range(0, 100000) - .Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1))); + .Select(i => + { + return Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); + }); await Task.WhenAll(tasks); AssertIsLess(3_000_000, GC.GetTotalMemory(true));