@@ -1,9 +1,9 @@ | |||
using System; | |||
using System.Collections.Concurrent; | |||
using System.Text; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Client; | |||
using MQTTnet.Internal; | |||
using MQTTnet.Protocol; | |||
namespace MQTTnet.Extensions.Rpc | |||
@@ -12,7 +12,7 @@ namespace MQTTnet.Extensions.Rpc | |||
{ | |||
private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>(); | |||
private readonly IMqttClient _mqttClient; | |||
public MqttRpcClient(IMqttClient mqttClient) | |||
{ | |||
_mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient)); | |||
@@ -22,10 +22,20 @@ namespace MQTTnet.Extensions.Rpc | |||
public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | |||
{ | |||
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel); | |||
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None); | |||
} | |||
public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) | |||
{ | |||
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken); | |||
} | |||
public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | |||
{ | |||
return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None); | |||
} | |||
public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | |||
public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) | |||
{ | |||
if (methodName == null) throw new ArgumentNullException(nameof(methodName)); | |||
@@ -54,7 +64,20 @@ namespace MQTTnet.Extensions.Rpc | |||
await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false); | |||
await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false); | |||
return await tcs.Task.TimeoutAfter(timeout); | |||
using (var timeoutCts = new CancellationTokenSource(timeout)) | |||
{ | |||
timeoutCts.Token.Register(() => | |||
{ | |||
if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled) | |||
{ | |||
tcs.TrySetCanceled(); | |||
} | |||
}); | |||
var result = await tcs.Task.ConfigureAwait(false); | |||
timeoutCts.Cancel(false); | |||
return result; | |||
} | |||
} | |||
finally | |||
{ | |||
@@ -7,7 +7,6 @@ using System.Threading.Tasks; | |||
using MQTTnet.Channel; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Internal; | |||
using MQTTnet.Packets; | |||
using MQTTnet.Serializer; | |||
@@ -18,11 +17,11 @@ namespace MQTTnet.Adapter | |||
private const uint ErrorOperationAborted = 0x800703E3; | |||
private const int ReadBufferSize = 4096; // TODO: Move buffer size to config | |||
private bool _isDisposed; | |||
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); | |||
private readonly IMqttNetLogger _logger; | |||
private readonly IMqttChannel _channel; | |||
private bool _isDisposed; | |||
public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger) | |||
{ | |||
_logger = logger ?? throw new ArgumentNullException(nameof(logger)); | |||
@@ -37,7 +36,8 @@ namespace MQTTnet.Adapter | |||
ThrowIfDisposed(); | |||
_logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout); | |||
return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout)); | |||
return ExecuteAndWrapExceptionAsync(() => | |||
Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); | |||
} | |||
public Task DisconnectAsync(TimeSpan timeout) | |||
@@ -45,7 +45,8 @@ namespace MQTTnet.Adapter | |||
ThrowIfDisposed(); | |||
_logger.Verbose<MqttChannelAdapter>("Disconnecting [Timeout={0}]", timeout); | |||
return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); | |||
return ExecuteAndWrapExceptionAsync(() => | |||
Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, CancellationToken.None)); | |||
} | |||
public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) | |||
@@ -71,11 +72,11 @@ namespace MQTTnet.Adapter | |||
var packetData = PacketSerializer.Serialize(packet); | |||
return _channel.WriteAsync( | |||
return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( | |||
packetData.Array, | |||
packetData.Offset, | |||
packetData.Count, | |||
cancellationToken); | |||
ct), timeout, cancellationToken); | |||
}); | |||
} | |||
@@ -91,25 +92,7 @@ namespace MQTTnet.Adapter | |||
{ | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
var timeoutCts = new CancellationTokenSource(timeout); | |||
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); | |||
try | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false); | |||
} | |||
catch (OperationCanceledException exception) | |||
{ | |||
var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||
if (timedOut) | |||
{ | |||
throw new MqttCommunicationTimedOutException(exception); | |||
} | |||
else | |||
{ | |||
throw; | |||
} | |||
} | |||
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
@@ -232,7 +215,6 @@ namespace MQTTnet.Adapter | |||
public void Dispose() | |||
{ | |||
_isDisposed = true; | |||
_semaphore?.Dispose(); | |||
_channel?.Dispose(); | |||
} | |||
@@ -1,16 +1,15 @@ | |||
using System; | |||
using System.Collections.Concurrent; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Internal; | |||
using MQTTnet.Packets; | |||
namespace MQTTnet.Client | |||
{ | |||
public class MqttPacketDispatcher | |||
{ | |||
private readonly ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>> _awaiters = new ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>>(); | |||
private readonly IMqttNetLogger _logger; | |||
@@ -23,8 +22,8 @@ namespace MQTTnet.Client | |||
{ | |||
var packetAwaiter = AddPacketAwaiter(responseType, identifier); | |||
try | |||
{ | |||
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false); | |||
{ | |||
return await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, timeout, CancellationToken.None).ConfigureAwait(false); | |||
} | |||
catch (MqttCommunicationTimedOutException) | |||
{ | |||
@@ -49,7 +48,7 @@ namespace MQTTnet.Client | |||
var type = packet.GetType(); | |||
var key = new Tuple<ushort?, Type>(identifier, type); | |||
if (_awaiters.TryRemove(key, out var tcs)) | |||
{ | |||
@@ -74,8 +73,8 @@ namespace MQTTnet.Client | |||
identifier = 0; | |||
} | |||
var dictionaryKey = new Tuple<ushort?,Type>(identifier, responseType); | |||
if (!_awaiters.TryAdd(dictionaryKey,tcs)) | |||
var dictionaryKey = new Tuple<ushort?, Type>(identifier, responseType); | |||
if (!_awaiters.TryAdd(dictionaryKey, tcs)) | |||
{ | |||
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); | |||
} | |||
@@ -7,72 +7,52 @@ namespace MQTTnet.Internal | |||
{ | |||
public static class TaskExtensions | |||
{ | |||
public static async Task TimeoutAfter(this Task task, TimeSpan timeout) | |||
public static async Task TimeoutAfter(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
if (task == null) throw new ArgumentNullException(nameof(task)); | |||
if (action == null) throw new ArgumentNullException(nameof(action)); | |||
using (var timeoutCts = new CancellationTokenSource()) | |||
using (var timeoutCts = new CancellationTokenSource(timeout)) | |||
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) | |||
{ | |||
try | |||
{ | |||
var timeoutTask = Task.Delay(timeout, timeoutCts.Token); | |||
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); | |||
if (finishedTask == timeoutTask) | |||
{ | |||
throw new MqttCommunicationTimedOutException(); | |||
} | |||
if (task.IsCanceled) | |||
await action(linkedCts.Token).ConfigureAwait(false); | |||
} | |||
catch (OperationCanceledException exception) | |||
{ | |||
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||
if (timeoutReached) | |||
{ | |||
throw new TaskCanceledException(); | |||
throw new MqttCommunicationTimedOutException(exception); | |||
} | |||
if (task.IsFaulted) | |||
{ | |||
throw new MqttCommunicationException(task.Exception?.GetBaseException()); | |||
} | |||
} | |||
finally | |||
{ | |||
timeoutCts.Cancel(); | |||
throw; | |||
} | |||
} | |||
} | |||
public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout) | |||
public static async Task<TResult> TimeoutAfter<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
if (task == null) throw new ArgumentNullException(nameof(task)); | |||
if (action == null) throw new ArgumentNullException(nameof(action)); | |||
using (var timeoutCts = new CancellationTokenSource()) | |||
using (var timeoutCts = new CancellationTokenSource(timeout)) | |||
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) | |||
{ | |||
try | |||
{ | |||
var timeoutTask = Task.Delay(timeout, timeoutCts.Token); | |||
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); | |||
if (finishedTask == timeoutTask) | |||
{ | |||
throw new MqttCommunicationTimedOutException(); | |||
} | |||
if (task.IsCanceled) | |||
{ | |||
throw new TaskCanceledException(); | |||
} | |||
if (task.IsFaulted) | |||
return await action(linkedCts.Token).ConfigureAwait(false); | |||
} | |||
catch (OperationCanceledException exception) | |||
{ | |||
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||
if (timeoutReached) | |||
{ | |||
throw new MqttCommunicationException(task.Exception.GetBaseException()); | |||
throw new MqttCommunicationTimedOutException(exception); | |||
} | |||
return task.Result; | |||
} | |||
finally | |||
{ | |||
timeoutCts.Cancel(); | |||
throw; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -1,9 +1,9 @@ | |||
using System; | |||
using System.Linq; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Internal; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
@@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests | |||
[TestMethod] | |||
public async Task TimeoutAfter() | |||
{ | |||
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||
} | |||
[ExpectedException(typeof(MqttCommunicationTimedOutException))] | |||
[TestMethod] | |||
public async Task TimeoutAfterWithResult() | |||
{ | |||
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||
} | |||
[TestMethod] | |||
public async Task TimeoutAfterCompleteInTime() | |||
{ | |||
var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500)); | |||
var result = await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); | |||
Assert.AreEqual(5, result); | |||
} | |||
@@ -36,17 +36,17 @@ namespace MQTTnet.Core.Tests | |||
{ | |||
try | |||
{ | |||
await Task.Run(() => | |||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => | |||
{ | |||
var iis = new int[0]; | |||
iis[1] = 0; | |||
}).TimeoutAfter(TimeSpan.FromSeconds(1)); | |||
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None); | |||
Assert.Fail(); | |||
} | |||
catch (MqttCommunicationException e) | |||
catch (Exception e) | |||
{ | |||
Assert.IsTrue(e.InnerException is IndexOutOfRangeException); | |||
Assert.IsTrue(e is IndexOutOfRangeException); | |||
} | |||
} | |||
@@ -55,17 +55,18 @@ namespace MQTTnet.Core.Tests | |||
{ | |||
try | |||
{ | |||
await Task.Run(() => | |||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => | |||
{ | |||
var iis = new int[0]; | |||
return iis[1]; | |||
}).TimeoutAfter(TimeSpan.FromSeconds(1)); | |||
iis[1] = 0; | |||
return iis[0]; | |||
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None); | |||
Assert.Fail(); | |||
} | |||
catch (MqttCommunicationException e) | |||
catch (Exception e) | |||
{ | |||
Assert.IsTrue(e.InnerException is IndexOutOfRangeException); | |||
Assert.IsTrue(e is IndexOutOfRangeException); | |||
} | |||
} | |||
@@ -73,7 +74,10 @@ namespace MQTTnet.Core.Tests | |||
public async Task TimeoutAfterMemoryUsage() | |||
{ | |||
var tasks = Enumerable.Range(0, 100000) | |||
.Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1))); | |||
.Select(i => | |||
{ | |||
return Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); | |||
}); | |||
await Task.WhenAll(tasks); | |||
AssertIsLess(3_000_000, GC.GetTotalMemory(true)); | |||