Bladeren bron

Refactor the task timeout code to avoid still running tasks if the timeout is reached.

release/3.x.x
Christian 6 jaren geleden
bovenliggende
commit
109e4b2cf1
5 gewijzigde bestanden met toevoegingen van 86 en 98 verwijderingen
  1. +28
    -5
      Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs
  2. +9
    -27
      Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs
  3. +6
    -7
      Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs
  4. +25
    -45
      Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs
  5. +18
    -14
      Tests/MQTTnet.Core.Tests/ExtensionTests.cs

+ 28
- 5
Extensions/MQTTnet.Extensions.Rpc/MqttRpcClient.cs Bestand weergeven

@@ -1,9 +1,9 @@
using System;
using System.Collections.Concurrent;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Client;
using MQTTnet.Internal;
using MQTTnet.Protocol;

namespace MQTTnet.Extensions.Rpc
@@ -12,7 +12,7 @@ namespace MQTTnet.Extensions.Rpc
{
private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>();
private readonly IMqttClient _mqttClient;
public MqttRpcClient(IMqttClient mqttClient)
{
_mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient));
@@ -22,10 +22,20 @@ namespace MQTTnet.Extensions.Rpc

public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
{
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel);
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None);
}

public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
{
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken);
}

public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
{
return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None);
}

public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
{
if (methodName == null) throw new ArgumentNullException(nameof(methodName));

@@ -54,7 +64,20 @@ namespace MQTTnet.Extensions.Rpc
await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false);
await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false);

return await tcs.Task.TimeoutAfter(timeout);
using (var timeoutCts = new CancellationTokenSource(timeout))
{
timeoutCts.Token.Register(() =>
{
if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled)
{
tcs.TrySetCanceled();
}
});

var result = await tcs.Task.ConfigureAwait(false);
timeoutCts.Cancel(false);
return result;
}
}
finally
{


+ 9
- 27
Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs Bestand weergeven

@@ -7,7 +7,6 @@ using System.Threading.Tasks;
using MQTTnet.Channel;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Serializer;

@@ -18,11 +17,11 @@ namespace MQTTnet.Adapter
private const uint ErrorOperationAborted = 0x800703E3;
private const int ReadBufferSize = 4096; // TODO: Move buffer size to config

private bool _isDisposed;
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private readonly IMqttNetLogger _logger;
private readonly IMqttChannel _channel;

private bool _isDisposed;

public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger)
{
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
@@ -37,7 +36,8 @@ namespace MQTTnet.Adapter
ThrowIfDisposed();
_logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout);

return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout));
return ExecuteAndWrapExceptionAsync(() =>
Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken));
}

public Task DisconnectAsync(TimeSpan timeout)
@@ -45,7 +45,8 @@ namespace MQTTnet.Adapter
ThrowIfDisposed();
_logger.Verbose<MqttChannelAdapter>("Disconnecting [Timeout={0}]", timeout);

return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout));
return ExecuteAndWrapExceptionAsync(() =>
Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, CancellationToken.None));
}

public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets)
@@ -71,11 +72,11 @@ namespace MQTTnet.Adapter

var packetData = PacketSerializer.Serialize(packet);

return _channel.WriteAsync(
return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync(
packetData.Array,
packetData.Offset,
packetData.Count,
cancellationToken);
ct), timeout, cancellationToken);
});
}

@@ -91,25 +92,7 @@ namespace MQTTnet.Adapter
{
if (timeout > TimeSpan.Zero)
{
var timeoutCts = new CancellationTokenSource(timeout);
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);

try
{
receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested;
if (timedOut)
{
throw new MqttCommunicationTimedOutException(exception);
}
else
{
throw;
}
}
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false);
}
else
{
@@ -232,7 +215,6 @@ namespace MQTTnet.Adapter
public void Dispose()
{
_isDisposed = true;
_semaphore?.Dispose();
_channel?.Dispose();
}



+ 6
- 7
Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs Bestand weergeven

@@ -1,16 +1,15 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;

namespace MQTTnet.Client
{
public class MqttPacketDispatcher
{
private readonly ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>> _awaiters = new ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>>();
private readonly IMqttNetLogger _logger;

@@ -23,8 +22,8 @@ namespace MQTTnet.Client
{
var packetAwaiter = AddPacketAwaiter(responseType, identifier);
try
{
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false);
{
return await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, timeout, CancellationToken.None).ConfigureAwait(false);
}
catch (MqttCommunicationTimedOutException)
{
@@ -49,7 +48,7 @@ namespace MQTTnet.Client

var type = packet.GetType();
var key = new Tuple<ushort?, Type>(identifier, type);

if (_awaiters.TryRemove(key, out var tcs))
{
@@ -74,8 +73,8 @@ namespace MQTTnet.Client
identifier = 0;
}

var dictionaryKey = new Tuple<ushort?,Type>(identifier, responseType);
if (!_awaiters.TryAdd(dictionaryKey,tcs))
var dictionaryKey = new Tuple<ushort?, Type>(identifier, responseType);
if (!_awaiters.TryAdd(dictionaryKey, tcs))
{
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}.");
}


+ 25
- 45
Frameworks/MQTTnet.NetStandard/Internal/TaskExtensions.cs Bestand weergeven

@@ -7,72 +7,52 @@ namespace MQTTnet.Internal
{
public static class TaskExtensions
{
public static async Task TimeoutAfter(this Task task, TimeSpan timeout)
public static async Task TimeoutAfter(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken)
{
if (task == null) throw new ArgumentNullException(nameof(task));
if (action == null) throw new ArgumentNullException(nameof(action));

using (var timeoutCts = new CancellationTokenSource())
using (var timeoutCts = new CancellationTokenSource(timeout))
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken))
{
try
{
var timeoutTask = Task.Delay(timeout, timeoutCts.Token);
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false);
if (finishedTask == timeoutTask)
{
throw new MqttCommunicationTimedOutException();
}

if (task.IsCanceled)
await action(linkedCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested;
if (timeoutReached)
{
throw new TaskCanceledException();
throw new MqttCommunicationTimedOutException(exception);
}

if (task.IsFaulted)
{
throw new MqttCommunicationException(task.Exception?.GetBaseException());
}
}
finally
{
timeoutCts.Cancel();
throw;
}
}
}

public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
public static async Task<TResult> TimeoutAfter<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken)
{
if (task == null) throw new ArgumentNullException(nameof(task));
if (action == null) throw new ArgumentNullException(nameof(action));

using (var timeoutCts = new CancellationTokenSource())
using (var timeoutCts = new CancellationTokenSource(timeout))
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken))
{
try
{
var timeoutTask = Task.Delay(timeout, timeoutCts.Token);
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false);

if (finishedTask == timeoutTask)
{
throw new MqttCommunicationTimedOutException();
}

if (task.IsCanceled)
{
throw new TaskCanceledException();
}

if (task.IsFaulted)
return await action(linkedCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested;
if (timeoutReached)
{
throw new MqttCommunicationException(task.Exception.GetBaseException());
throw new MqttCommunicationTimedOutException(exception);
}

return task.Result;
}
finally
{
timeoutCts.Cancel();
throw;
}
}
}
}
}
}

+ 18
- 14
Tests/MQTTnet.Core.Tests/ExtensionTests.cs Bestand weergeven

@@ -1,9 +1,9 @@
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Exceptions;
using MQTTnet.Internal;

namespace MQTTnet.Core.Tests
{
@@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests
[TestMethod]
public async Task TimeoutAfter()
{
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100));
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
}

[ExpectedException(typeof(MqttCommunicationTimedOutException))]
[TestMethod]
public async Task TimeoutAfterWithResult()
{
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100));
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
}

[TestMethod]
public async Task TimeoutAfterCompleteInTime()
{
var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500));
var result = await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None);
Assert.AreEqual(5, result);
}

@@ -36,17 +36,17 @@ namespace MQTTnet.Core.Tests
{
try
{
await Task.Run(() =>
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() =>
{
var iis = new int[0];
iis[1] = 0;
}).TimeoutAfter(TimeSpan.FromSeconds(1));
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None);

Assert.Fail();
}
catch (MqttCommunicationException e)
catch (Exception e)
{
Assert.IsTrue(e.InnerException is IndexOutOfRangeException);
Assert.IsTrue(e is IndexOutOfRangeException);
}
}

@@ -55,17 +55,18 @@ namespace MQTTnet.Core.Tests
{
try
{
await Task.Run(() =>
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() =>
{
var iis = new int[0];
return iis[1];
}).TimeoutAfter(TimeSpan.FromSeconds(1));
iis[1] = 0;
return iis[0];
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None);

Assert.Fail();
}
catch (MqttCommunicationException e)
catch (Exception e)
{
Assert.IsTrue(e.InnerException is IndexOutOfRangeException);
Assert.IsTrue(e is IndexOutOfRangeException);
}
}

@@ -73,7 +74,10 @@ namespace MQTTnet.Core.Tests
public async Task TimeoutAfterMemoryUsage()
{
var tasks = Enumerable.Range(0, 100000)
.Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1)));
.Select(i =>
{
return Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None);
});

await Task.WhenAll(tasks);
AssertIsLess(3_000_000, GC.GetTotalMemory(true));


Laden…
Annuleren
Opslaan