@@ -1,9 +1,9 @@ | |||||
using System; | using System; | ||||
using System.Collections.Concurrent; | using System.Collections.Concurrent; | ||||
using System.Text; | using System.Text; | ||||
using System.Threading; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Client; | using MQTTnet.Client; | ||||
using MQTTnet.Internal; | |||||
using MQTTnet.Protocol; | using MQTTnet.Protocol; | ||||
namespace MQTTnet.Extensions.Rpc | namespace MQTTnet.Extensions.Rpc | ||||
@@ -12,7 +12,7 @@ namespace MQTTnet.Extensions.Rpc | |||||
{ | { | ||||
private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>(); | private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>(); | ||||
private readonly IMqttClient _mqttClient; | private readonly IMqttClient _mqttClient; | ||||
public MqttRpcClient(IMqttClient mqttClient) | public MqttRpcClient(IMqttClient mqttClient) | ||||
{ | { | ||||
_mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient)); | _mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient)); | ||||
@@ -22,10 +22,20 @@ namespace MQTTnet.Extensions.Rpc | |||||
public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | ||||
{ | { | ||||
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel); | |||||
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None); | |||||
} | |||||
public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) | |||||
{ | |||||
return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken); | |||||
} | |||||
public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | |||||
{ | |||||
return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None); | |||||
} | } | ||||
public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel) | |||||
public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken) | |||||
{ | { | ||||
if (methodName == null) throw new ArgumentNullException(nameof(methodName)); | if (methodName == null) throw new ArgumentNullException(nameof(methodName)); | ||||
@@ -54,7 +64,20 @@ namespace MQTTnet.Extensions.Rpc | |||||
await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false); | await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false); | ||||
await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false); | await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false); | ||||
return await tcs.Task.TimeoutAfter(timeout); | |||||
using (var timeoutCts = new CancellationTokenSource(timeout)) | |||||
{ | |||||
timeoutCts.Token.Register(() => | |||||
{ | |||||
if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled) | |||||
{ | |||||
tcs.TrySetCanceled(); | |||||
} | |||||
}); | |||||
var result = await tcs.Task.ConfigureAwait(false); | |||||
timeoutCts.Cancel(false); | |||||
return result; | |||||
} | |||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
@@ -7,7 +7,6 @@ using System.Threading.Tasks; | |||||
using MQTTnet.Channel; | using MQTTnet.Channel; | ||||
using MQTTnet.Diagnostics; | using MQTTnet.Diagnostics; | ||||
using MQTTnet.Exceptions; | using MQTTnet.Exceptions; | ||||
using MQTTnet.Internal; | |||||
using MQTTnet.Packets; | using MQTTnet.Packets; | ||||
using MQTTnet.Serializer; | using MQTTnet.Serializer; | ||||
@@ -18,11 +17,11 @@ namespace MQTTnet.Adapter | |||||
private const uint ErrorOperationAborted = 0x800703E3; | private const uint ErrorOperationAborted = 0x800703E3; | ||||
private const int ReadBufferSize = 4096; // TODO: Move buffer size to config | private const int ReadBufferSize = 4096; // TODO: Move buffer size to config | ||||
private bool _isDisposed; | |||||
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); | |||||
private readonly IMqttNetLogger _logger; | private readonly IMqttNetLogger _logger; | ||||
private readonly IMqttChannel _channel; | private readonly IMqttChannel _channel; | ||||
private bool _isDisposed; | |||||
public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger) | public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetLogger logger) | ||||
{ | { | ||||
_logger = logger ?? throw new ArgumentNullException(nameof(logger)); | _logger = logger ?? throw new ArgumentNullException(nameof(logger)); | ||||
@@ -37,7 +36,8 @@ namespace MQTTnet.Adapter | |||||
ThrowIfDisposed(); | ThrowIfDisposed(); | ||||
_logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout); | _logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout); | ||||
return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout)); | |||||
return ExecuteAndWrapExceptionAsync(() => | |||||
Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); | |||||
} | } | ||||
public Task DisconnectAsync(TimeSpan timeout) | public Task DisconnectAsync(TimeSpan timeout) | ||||
@@ -45,7 +45,8 @@ namespace MQTTnet.Adapter | |||||
ThrowIfDisposed(); | ThrowIfDisposed(); | ||||
_logger.Verbose<MqttChannelAdapter>("Disconnecting [Timeout={0}]", timeout); | _logger.Verbose<MqttChannelAdapter>("Disconnecting [Timeout={0}]", timeout); | ||||
return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); | |||||
return ExecuteAndWrapExceptionAsync(() => | |||||
Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, CancellationToken.None)); | |||||
} | } | ||||
public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) | public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) | ||||
@@ -71,11 +72,11 @@ namespace MQTTnet.Adapter | |||||
var packetData = PacketSerializer.Serialize(packet); | var packetData = PacketSerializer.Serialize(packet); | ||||
return _channel.WriteAsync( | |||||
return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( | |||||
packetData.Array, | packetData.Array, | ||||
packetData.Offset, | packetData.Offset, | ||||
packetData.Count, | packetData.Count, | ||||
cancellationToken); | |||||
ct), timeout, cancellationToken); | |||||
}); | }); | ||||
} | } | ||||
@@ -91,25 +92,7 @@ namespace MQTTnet.Adapter | |||||
{ | { | ||||
if (timeout > TimeSpan.Zero) | if (timeout > TimeSpan.Zero) | ||||
{ | { | ||||
var timeoutCts = new CancellationTokenSource(timeout); | |||||
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); | |||||
try | |||||
{ | |||||
receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false); | |||||
} | |||||
catch (OperationCanceledException exception) | |||||
{ | |||||
var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||||
if (timedOut) | |||||
{ | |||||
throw new MqttCommunicationTimedOutException(exception); | |||||
} | |||||
else | |||||
{ | |||||
throw; | |||||
} | |||||
} | |||||
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -232,7 +215,6 @@ namespace MQTTnet.Adapter | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
_isDisposed = true; | _isDisposed = true; | ||||
_semaphore?.Dispose(); | |||||
_channel?.Dispose(); | _channel?.Dispose(); | ||||
} | } | ||||
@@ -1,16 +1,15 @@ | |||||
using System; | using System; | ||||
using System.Collections.Concurrent; | using System.Collections.Concurrent; | ||||
using System.Threading; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Diagnostics; | using MQTTnet.Diagnostics; | ||||
using MQTTnet.Exceptions; | using MQTTnet.Exceptions; | ||||
using MQTTnet.Internal; | |||||
using MQTTnet.Packets; | using MQTTnet.Packets; | ||||
namespace MQTTnet.Client | namespace MQTTnet.Client | ||||
{ | { | ||||
public class MqttPacketDispatcher | public class MqttPacketDispatcher | ||||
{ | { | ||||
private readonly ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>> _awaiters = new ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>>(); | private readonly ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>> _awaiters = new ConcurrentDictionary<Tuple<ushort?, Type>, TaskCompletionSource<MqttBasePacket>>(); | ||||
private readonly IMqttNetLogger _logger; | private readonly IMqttNetLogger _logger; | ||||
@@ -23,8 +22,8 @@ namespace MQTTnet.Client | |||||
{ | { | ||||
var packetAwaiter = AddPacketAwaiter(responseType, identifier); | var packetAwaiter = AddPacketAwaiter(responseType, identifier); | ||||
try | try | ||||
{ | |||||
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false); | |||||
{ | |||||
return await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, timeout, CancellationToken.None).ConfigureAwait(false); | |||||
} | } | ||||
catch (MqttCommunicationTimedOutException) | catch (MqttCommunicationTimedOutException) | ||||
{ | { | ||||
@@ -49,7 +48,7 @@ namespace MQTTnet.Client | |||||
var type = packet.GetType(); | var type = packet.GetType(); | ||||
var key = new Tuple<ushort?, Type>(identifier, type); | var key = new Tuple<ushort?, Type>(identifier, type); | ||||
if (_awaiters.TryRemove(key, out var tcs)) | if (_awaiters.TryRemove(key, out var tcs)) | ||||
{ | { | ||||
@@ -74,8 +73,8 @@ namespace MQTTnet.Client | |||||
identifier = 0; | identifier = 0; | ||||
} | } | ||||
var dictionaryKey = new Tuple<ushort?,Type>(identifier, responseType); | |||||
if (!_awaiters.TryAdd(dictionaryKey,tcs)) | |||||
var dictionaryKey = new Tuple<ushort?, Type>(identifier, responseType); | |||||
if (!_awaiters.TryAdd(dictionaryKey, tcs)) | |||||
{ | { | ||||
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); | throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); | ||||
} | } | ||||
@@ -7,72 +7,52 @@ namespace MQTTnet.Internal | |||||
{ | { | ||||
public static class TaskExtensions | public static class TaskExtensions | ||||
{ | { | ||||
public static async Task TimeoutAfter(this Task task, TimeSpan timeout) | |||||
public static async Task TimeoutAfter(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken) | |||||
{ | { | ||||
if (task == null) throw new ArgumentNullException(nameof(task)); | |||||
if (action == null) throw new ArgumentNullException(nameof(action)); | |||||
using (var timeoutCts = new CancellationTokenSource()) | |||||
using (var timeoutCts = new CancellationTokenSource(timeout)) | |||||
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
var timeoutTask = Task.Delay(timeout, timeoutCts.Token); | |||||
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); | |||||
if (finishedTask == timeoutTask) | |||||
{ | |||||
throw new MqttCommunicationTimedOutException(); | |||||
} | |||||
if (task.IsCanceled) | |||||
await action(linkedCts.Token).ConfigureAwait(false); | |||||
} | |||||
catch (OperationCanceledException exception) | |||||
{ | |||||
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||||
if (timeoutReached) | |||||
{ | { | ||||
throw new TaskCanceledException(); | |||||
throw new MqttCommunicationTimedOutException(exception); | |||||
} | } | ||||
if (task.IsFaulted) | |||||
{ | |||||
throw new MqttCommunicationException(task.Exception?.GetBaseException()); | |||||
} | |||||
} | |||||
finally | |||||
{ | |||||
timeoutCts.Cancel(); | |||||
throw; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout) | |||||
public static async Task<TResult> TimeoutAfter<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken) | |||||
{ | { | ||||
if (task == null) throw new ArgumentNullException(nameof(task)); | |||||
if (action == null) throw new ArgumentNullException(nameof(action)); | |||||
using (var timeoutCts = new CancellationTokenSource()) | |||||
using (var timeoutCts = new CancellationTokenSource(timeout)) | |||||
using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
var timeoutTask = Task.Delay(timeout, timeoutCts.Token); | |||||
var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); | |||||
if (finishedTask == timeoutTask) | |||||
{ | |||||
throw new MqttCommunicationTimedOutException(); | |||||
} | |||||
if (task.IsCanceled) | |||||
{ | |||||
throw new TaskCanceledException(); | |||||
} | |||||
if (task.IsFaulted) | |||||
return await action(linkedCts.Token).ConfigureAwait(false); | |||||
} | |||||
catch (OperationCanceledException exception) | |||||
{ | |||||
var timeoutReached = timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||||
if (timeoutReached) | |||||
{ | { | ||||
throw new MqttCommunicationException(task.Exception.GetBaseException()); | |||||
throw new MqttCommunicationTimedOutException(exception); | |||||
} | } | ||||
return task.Result; | |||||
} | |||||
finally | |||||
{ | |||||
timeoutCts.Cancel(); | |||||
throw; | |||||
} | } | ||||
} | } | ||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,9 +1,9 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Threading; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using MQTTnet.Exceptions; | using MQTTnet.Exceptions; | ||||
using MQTTnet.Internal; | |||||
namespace MQTTnet.Core.Tests | namespace MQTTnet.Core.Tests | ||||
{ | { | ||||
@@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests | |||||
[TestMethod] | [TestMethod] | ||||
public async Task TimeoutAfter() | public async Task TimeoutAfter() | ||||
{ | { | ||||
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||||
} | } | ||||
[ExpectedException(typeof(MqttCommunicationTimedOutException))] | [ExpectedException(typeof(MqttCommunicationTimedOutException))] | ||||
[TestMethod] | [TestMethod] | ||||
public async Task TimeoutAfterWithResult() | public async Task TimeoutAfterWithResult() | ||||
{ | { | ||||
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public async Task TimeoutAfterCompleteInTime() | public async Task TimeoutAfterCompleteInTime() | ||||
{ | { | ||||
var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500)); | |||||
var result = await Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); | |||||
Assert.AreEqual(5, result); | Assert.AreEqual(5, result); | ||||
} | } | ||||
@@ -36,17 +36,17 @@ namespace MQTTnet.Core.Tests | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
await Task.Run(() => | |||||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => | |||||
{ | { | ||||
var iis = new int[0]; | var iis = new int[0]; | ||||
iis[1] = 0; | iis[1] = 0; | ||||
}).TimeoutAfter(TimeSpan.FromSeconds(1)); | |||||
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None); | |||||
Assert.Fail(); | Assert.Fail(); | ||||
} | } | ||||
catch (MqttCommunicationException e) | |||||
catch (Exception e) | |||||
{ | { | ||||
Assert.IsTrue(e.InnerException is IndexOutOfRangeException); | |||||
Assert.IsTrue(e is IndexOutOfRangeException); | |||||
} | } | ||||
} | } | ||||
@@ -55,17 +55,18 @@ namespace MQTTnet.Core.Tests | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
await Task.Run(() => | |||||
await Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => | |||||
{ | { | ||||
var iis = new int[0]; | var iis = new int[0]; | ||||
return iis[1]; | |||||
}).TimeoutAfter(TimeSpan.FromSeconds(1)); | |||||
iis[1] = 0; | |||||
return iis[0]; | |||||
}, ct), TimeSpan.FromSeconds(1), CancellationToken.None); | |||||
Assert.Fail(); | Assert.Fail(); | ||||
} | } | ||||
catch (MqttCommunicationException e) | |||||
catch (Exception e) | |||||
{ | { | ||||
Assert.IsTrue(e.InnerException is IndexOutOfRangeException); | |||||
Assert.IsTrue(e is IndexOutOfRangeException); | |||||
} | } | ||||
} | } | ||||
@@ -73,7 +74,10 @@ namespace MQTTnet.Core.Tests | |||||
public async Task TimeoutAfterMemoryUsage() | public async Task TimeoutAfterMemoryUsage() | ||||
{ | { | ||||
var tasks = Enumerable.Range(0, 100000) | var tasks = Enumerable.Range(0, 100000) | ||||
.Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1))); | |||||
.Select(i => | |||||
{ | |||||
return Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); | |||||
}); | |||||
await Task.WhenAll(tasks); | await Task.WhenAll(tasks); | ||||
AssertIsLess(3_000_000, GC.GetTotalMemory(true)); | AssertIsLess(3_000_000, GC.GetTotalMemory(true)); | ||||