diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index e42b5e5..63740b2 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -1,10 +1,12 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; using MQTTnet.Core.Serializer; @@ -24,7 +26,7 @@ namespace MQTTnet.Core.Adapter public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) { - return ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout); + return _channel.ConnectAsync(options).TimeoutAfter(timeout); } public Task DisconnectAsync() @@ -38,7 +40,7 @@ namespace MQTTnet.Core.Adapter var writeBuffer = PacketSerializer.Serialize(packet); _sendTask = SendAsync( writeBuffer ); - return ExecuteWithTimeoutAsync(_sendTask, timeout); + return _sendTask.TimeoutAfter(timeout); } private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write @@ -54,7 +56,7 @@ namespace MQTTnet.Core.Adapter Tuple tuple; if (timeout > TimeSpan.Zero) { - tuple = await ExecuteWithTimeoutAsync(ReceiveAsync(), timeout).ConfigureAwait(false); + tuple = await ReceiveAsync().TimeoutAfter(timeout).ConfigureAwait(false); } else { @@ -96,35 +98,5 @@ namespace MQTTnet.Core.Adapter return Tuple.Create(header, body); } - - private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) - { - var timeoutTask = Task.Delay(timeout); - if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception); - } - - return task.Result; - } - - private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) - { - var timeoutTask = Task.Delay(timeout); - if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception); - } - } } } \ No newline at end of file diff --git a/MQTTnet.Core/Client/MqttPacketDispatcher.cs b/MQTTnet.Core/Client/MqttPacketDispatcher.cs index f057b97..6d369b3 100644 --- a/MQTTnet.Core/Client/MqttPacketDispatcher.cs +++ b/MQTTnet.Core/Client/MqttPacketDispatcher.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Threading.Tasks; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; using System.Collections.Concurrent; @@ -22,16 +23,19 @@ namespace MQTTnet.Core.Client var packetAwaiter = AddPacketAwaiter(request, responseType); DispatchPendingPackets(); - var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; - RemovePacketAwaiter(request, responseType); - - if (hasTimeout) + try { - MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); - throw new MqttCommunicationTimedOutException(); + return await packetAwaiter.Task.TimeoutAfter( timeout ); + } + catch ( MqttCommunicationTimedOutException ) + { + MqttTrace.Warning( nameof( MqttPacketDispatcher ), "Timeout while waiting for packet." ); + throw; + } + finally + { + RemovePacketAwaiter(request, responseType); } - - return packetAwaiter.Task.Result; } public void Dispatch(MqttBasePacket packet) diff --git a/MQTTnet.Core/Internal/TaskExtensions.cs b/MQTTnet.Core/Internal/TaskExtensions.cs new file mode 100644 index 0000000..eb8f485 --- /dev/null +++ b/MQTTnet.Core/Internal/TaskExtensions.cs @@ -0,0 +1,55 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Core.Exceptions; + +namespace MQTTnet.Core.Internal +{ + public static class TaskExtensions + { + public static Task TimeoutAfter( this Task task, TimeSpan timeout ) + { + return TimeoutAfter( task.ContinueWith( t => 0 ), timeout ); + } + + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + using (var cancellationTokenSource = new CancellationTokenSource()) + { + var tcs = new TaskCompletionSource(); + + cancellationTokenSource.Token.Register(() => + { + tcs.TrySetCanceled(); + } ); + + try + { + cancellationTokenSource.CancelAfter(timeout); + task.ContinueWith( t => + { + if (t.IsFaulted) + { + tcs.TrySetException(t.Exception); + } + + if (t.IsCompleted) + { + tcs.TrySetResult(t.Result); + } + }, cancellationTokenSource.Token ); + + return await tcs.Task; + } + catch (TaskCanceledException) + { + throw new MqttCommunicationTimedOutException(); + } + catch (Exception e) + { + throw new MqttCommunicationException(e); + } + } + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs new file mode 100644 index 0000000..1a3cd29 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -0,0 +1,33 @@ +using System; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class ExtensionTests + { + [ExpectedException(typeof( MqttCommunicationTimedOutException ) )] + [TestMethod] + public async Task TestTimeoutAfter() + { + await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + } + + [ExpectedException(typeof( MqttCommunicationTimedOutException))] + [TestMethod] + public async Task TestTimeoutAfterWithResult() + { + await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + } + + [TestMethod] + public async Task TestTimeoutAfterCompleteInTime() + { + var result = await Task.Delay( TimeSpan.FromMilliseconds( 100 ) ).ContinueWith( t => 5 ).TimeoutAfter( TimeSpan.FromMilliseconds( 500 ) ); + Assert.AreEqual( 5, result ); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj b/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj index 88bf8f2..2e20398 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj +++ b/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj @@ -86,6 +86,7 @@ +