Browse Source

Refactor the task timeout code to avoid still running tasks if the timeout is reached.

release/3.x.x
Christian 6 years ago
parent
commit
109e4b2cf1
5 changed files with 86 additions and 98 deletions
  1. +28
    -5
      Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs
  2. +9
    -27
      Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs
  3. +6
    -7
      Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs
  4. +25
    -45
      Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs
  5. +18
    -14
      Tests/MQTTnet.Core.Tests/ExtensionTests.cs

+ 28
- 5
Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs View File

@@ -1,9 +1,9 @@
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Text; using System.Text;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Client; using MQTTnet.Client;
using MQTTnet.Internal;
using MQTTnet.Protocol; using MQTTnet.Protocol;


namespace MQTTnet.Extensions.Rpc namespace MQTTnet.Extensions.Rpc
@@ -12,7 +12,7 @@ namespace MQTTnet.Extensions.Rpc
{ {
private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>(); private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>();
private readonly IMqttClient _mqttClient; private readonly IMqttClient _mqttClient;
public MqttRpcClient(IMqttClient mqttClient) public MqttRpcClient(IMqttClient mqttClient)
{ {
_mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient)); _mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient));
@@ -22,10 +22,20 @@ namespace MQTTnet.Extensions.Rpc


public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
{ {
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel);
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None);
}

public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
{
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken);
}

public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
{
return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None);
} }


public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
{ {
if (methodName == null) throw new ArgumentNullException(nameof(methodName)); if (methodName == null) throw new ArgumentNullException(nameof(methodName));


@@ -54,7 +64,20 @@ namespace MQTTnet.Extensions.Rpc
await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false); await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false);
await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false); await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false);


return await tcs.Task.TimeoutAfter(timeout);
using (var timeoutCts = new CancellationTokenSource(timeout))
{
timeoutCts.Token.Register(() =>
{
if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled)
{
tcs.TrySetCanceled();
}
});

var result = await tcs.Task.ConfigureAwait(false);
timeoutCts.Cancel(false);
return result;
}
} }
finally finally
{ {


+ 9
- 27
Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs View File

@@ -7,7 +7,6 @@ using System.Threading.Tasks;
using MQTTnet.Channel; using MQTTnet.Channel;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Exceptions; using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets; using MQTTnet.Packets;
using MQTTnet.Serializer; using MQTTnet.Serializer;


@@ -18,11 +17,11 @@ namespace MQTTnet.Adapter
private const uint ErrorOperationAborted = 0x800703E3; private const uint ErrorOperationAborted = 0x800703E3;
private const int ReadBufferSize = 4096; // TODO: Move buffer size to config private const int ReadBufferSize = 4096; // TODO: Move buffer size to config


private bool _isDisposed;
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private readonly IMqttNetLogger _logger; private readonly IMqttNetLogger _logger;
private readonly IMqttChannel _channel; private readonly IMqttChannel _channel;


private bool _isDisposed;

public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger) public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger)
{ {
_logger = logger ?? throw new ArgumentNullException(nameof(logger)); _logger = logger ?? throw new ArgumentNullException(nameof(logger));
@@ -37,7 +36,8 @@ namespace MQTTnet.Adapter
ThrowIfDisposed(); ThrowIfDisposed();
_logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout); _logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout);


return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout));
return ExecuteAndWrapExceptionAsync(() =>
Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken));
} }


public Task DisconnectAsync(TimeSpan timeout) public Task DisconnectAsync(TimeSpan timeout)
@@ -45,7 +45,8 @@ namespace MQTTnet.Adapter
ThrowIfDisposed(); ThrowIfDisposed();
_logger.Verbose<MqttChannelAdapter>("Disconnecting [Timeout={0}]", timeout); _logger.Verbose<MqttChannelAdapter>("Disconnecting [Timeout={0}]", timeout);


return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout));
return ExecuteAndWrapExceptionAsync(() =>
Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, CancellationToken.None));
} }


public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets)
@@ -71,11 +72,11 @@ namespace MQTTnet.Adapter


var packetData = PacketSerializer.Serialize(packet); var packetData = PacketSerializer.Serialize(packet);


return _channel.WriteAsync(
return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync(
packetData.Array, packetData.Array,
packetData.Offset, packetData.Offset,
packetData.Count, packetData.Count,
cancellationToken);
ct), timeout, cancellationToken);
}); });
} }


@@ -91,25 +92,7 @@ namespace MQTTnet.Adapter
{ {
if (timeout > TimeSpan.Zero) if (timeout > TimeSpan.Zero)
{ {
var timeoutCts = new CancellationTokenSource(timeout);
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);

try
{
receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested;
if (timedOut)
{
throw new MqttCommunicationTimedOutException(exception);
}
else
{
throw;
}
}
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false);
} }
else else
{ {
@@ -232,7 +215,6 @@ namespace MQTTnet.Adapter
public void Dispose() public void Dispose()
{ {
_isDisposed = true; _isDisposed = true;
_semaphore?.Dispose();
_channel?.Dispose(); _channel?.Dispose();
} }




+ 6
- 7
Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs View File

@@ -1,16 +1,15 @@
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Exceptions; using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets; using MQTTnet.Packets;


namespace MQTTnet.Client namespace MQTTnet.Client
{ {
public class MqttPacketDispatcher public class MqttPacketDispatcher
{ {
private readonly ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>> _awaiters = new ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>>(); private readonly ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>> _awaiters = new ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>>();
private readonly IMqttNetLogger _logger; private readonly IMqttNetLogger _logger;


@@ -23,8 +22,8 @@ namespace MQTTnet.Client
{ {
var packetAwaiter = AddPacketAwaiter(responseType, identifier); var packetAwaiter = AddPacketAwaiter(responseType, identifier);
try try
{
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false);
{
return await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, timeout, CancellationToken.None).ConfigureAwait(false);
} }
catch (MqttCommunicationTimedOutException) catch (MqttCommunicationTimedOutException)
{ {
@@ -49,7 +48,7 @@ namespace MQTTnet.Client


var type = packet.GetType(); var type = packet.GetType();
var key = new Tuple<ushort?, Type>(identifier, type); var key = new Tuple<ushort?, Type>(identifier, type);


if (_awaiters.TryRemove(key, out var tcs)) if (_awaiters.TryRemove(key, out var tcs))
{ {
@@ -74,8 +73,8 @@ namespace MQTTnet.Client
identifier = 0; identifier = 0;
} }


var dictionaryKey = new Tuple<ushort?,Type>(identifier, responseType);
if (!_awaiters.TryAdd(dictionaryKey,tcs))
var dictionaryKey = new Tuple<ushort?, Type>(identifier, responseType);
if (!_awaiters.TryAdd(dictionaryKey, tcs))
{ {
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}.");
} }


+ 25
- 45
Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs View File

@@ -7,72 +7,52 @@ namespace MQTTnet.Internal
{ {
public static class TaskExtensions public static class TaskExtensions
{ {
public static async Task TimeoutAfter(this Task task, TimeSpan timeout)
public static async Task TimeoutAfter(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken)
{ {
if (task == null) throw new ArgumentNullException(nameof(task));
if (action == null) throw new ArgumentNullException(nameof(action));


using (var timeoutCts = new CancellationTokenSource())
using (var timeoutCts = new CancellationTokenSource(timeout))
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken))
{ {
try try
{ {
var timeoutTask = Task.Delay(timeout, timeoutCts.Token);
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false);
if (finishedTask == timeoutTask)
{
throw new MqttCommunicationTimedOutException();
}

if (task.IsCanceled)
await action(linkedCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested;
if (timeoutReached)
{ {
throw new TaskCanceledException();
throw new MqttCommunicationTimedOutException(exception);
} }


if (task.IsFaulted)
{
throw new MqttCommunicationException(task.Exception?.GetBaseException());
}
}
finally
{
timeoutCts.Cancel();
throw;
} }
} }
} }


public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
public static async Task<TResult> TimeoutAfter<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken)
{ {
if (task == null) throw new ArgumentNullException(nameof(task));
if (action == null) throw new ArgumentNullException(nameof(action));


using (var timeoutCts = new CancellationTokenSource())
using (var timeoutCts = new CancellationTokenSource(timeout))
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken))
{ {
try try
{ {
var timeoutTask = Task.Delay(timeout, timeoutCts.Token);
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false);

if (finishedTask == timeoutTask)
{
throw new MqttCommunicationTimedOutException();
}

if (task.IsCanceled)
{
throw new TaskCanceledException();
}

if (task.IsFaulted)
return await action(linkedCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested;
if (timeoutReached)
{ {
throw new MqttCommunicationException(task.Exception.GetBaseException());
throw new MqttCommunicationTimedOutException(exception);
} }


return task.Result;
}
finally
{
timeoutCts.Cancel();
throw;
} }
} }
}
}
} }
} }

+ 18
- 14
Tests/MQTTnet.Core.Tests/ExtensionTests.cs View File

@@ -1,9 +1,9 @@
using System; using System;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Exceptions; using MQTTnet.Exceptions;
using MQTTnet.Internal;


namespace MQTTnet.Core.Tests namespace MQTTnet.Core.Tests
{ {
@@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public async Task TimeoutAfter() public async Task TimeoutAfter()
{ {
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100));
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
} }


[ExpectedException(typeof(MqttCommunicationTimedOutException))] [ExpectedException(typeof(MqttCommunicationTimedOutException))]
[TestMethod] [TestMethod]
public async Task TimeoutAfterWithResult() public async Task TimeoutAfterWithResult()
{ {
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100));
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
} }


[TestMethod] [TestMethod]
public async Task TimeoutAfterCompleteInTime() public async Task TimeoutAfterCompleteInTime()
{ {
var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500));
var result = await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None);
Assert.AreEqual(5, result); Assert.AreEqual(5, result);
} }


@@ -36,17 +36,17 @@ namespace MQTTnet.Core.Tests
{ {
try try
{ {
await Task.Run(() =>
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() =>
{ {
var iis = new int[0]; var iis = new int[0];
iis[1] = 0; iis[1] = 0;
}).TimeoutAfter(TimeSpan.FromSeconds(1));
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None);


Assert.Fail(); Assert.Fail();
} }
catch (MqttCommunicationException e)
catch (Exception e)
{ {
Assert.IsTrue(e.InnerException is IndexOutOfRangeException);
Assert.IsTrue(e is IndexOutOfRangeException);
} }
} }


@@ -55,17 +55,18 @@ namespace MQTTnet.Core.Tests
{ {
try try
{ {
await Task.Run(() =>
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() =>
{ {
var iis = new int[0]; var iis = new int[0];
return iis[1];
}).TimeoutAfter(TimeSpan.FromSeconds(1));
iis[1] = 0;
return iis[0];
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None);


Assert.Fail(); Assert.Fail();
} }
catch (MqttCommunicationException e)
catch (Exception e)
{ {
Assert.IsTrue(e.InnerException is IndexOutOfRangeException);
Assert.IsTrue(e is IndexOutOfRangeException);
} }
} }


@@ -73,7 +74,10 @@ namespace MQTTnet.Core.Tests
public async Task TimeoutAfterMemoryUsage() public async Task TimeoutAfterMemoryUsage()
{ {
var tasks = Enumerable.Range(0, 100000) var tasks = Enumerable.Range(0, 100000)
.Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1)));
.Select(i =>
{
return Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None);
});


await Task.WhenAll(tasks); await Task.WhenAll(tasks);
AssertIsLess(3_000_000, GC.GetTotalMemory(true)); AssertIsLess(3_000_000, GC.GetTotalMemory(true));


Loading…
Cancel
Save