From 090e59e99afdf1e278f0cab4a71639a98a2a5a8b Mon Sep 17 00:00:00 2001 From: Eggers Jan Date: Mon, 11 Sep 2017 17:50:50 +0200 Subject: [PATCH] unfifed timeout handling and fixed memory usage due to Task.Delay tasks for each send package are present for the duration of the timeout resulting in memory usage. new approach uses cancellationtoken that will be cleaned up directly if operation completes before timeout --- .../MqttChannelCommunicationAdapter.cs | 38 ++----------- MQTTnet.Core/Client/MqttPacketDispatcher.cs | 20 ++++--- MQTTnet.Core/Internal/TaskExtensions.cs | 55 +++++++++++++++++++ Tests/MQTTnet.Core.Tests/ExtensionTests.cs | 33 +++++++++++ .../MQTTnet.Core.Tests.csproj | 1 + 5 files changed, 106 insertions(+), 41 deletions(-) create mode 100644 MQTTnet.Core/Internal/TaskExtensions.cs create mode 100644 Tests/MQTTnet.Core.Tests/ExtensionTests.cs diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index e42b5e5..63740b2 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -1,10 +1,12 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; using MQTTnet.Core.Serializer; @@ -24,7 +26,7 @@ namespace MQTTnet.Core.Adapter public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) { - return ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout); + return _channel.ConnectAsync(options).TimeoutAfter(timeout); } public Task DisconnectAsync() @@ -38,7 +40,7 @@ namespace MQTTnet.Core.Adapter var writeBuffer = PacketSerializer.Serialize(packet); _sendTask = SendAsync( writeBuffer ); - return ExecuteWithTimeoutAsync(_sendTask, timeout); + return _sendTask.TimeoutAfter(timeout); } private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write @@ -54,7 +56,7 @@ namespace MQTTnet.Core.Adapter Tuple tuple; if (timeout > TimeSpan.Zero) { - tuple = await ExecuteWithTimeoutAsync(ReceiveAsync(), timeout).ConfigureAwait(false); + tuple = await ReceiveAsync().TimeoutAfter(timeout).ConfigureAwait(false); } else { @@ -96,35 +98,5 @@ namespace MQTTnet.Core.Adapter return Tuple.Create(header, body); } - - private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) - { - var timeoutTask = Task.Delay(timeout); - if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception); - } - - return task.Result; - } - - private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) - { - var timeoutTask = Task.Delay(timeout); - if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception); - } - } } } \ No newline at end of file diff --git a/MQTTnet.Core/Client/MqttPacketDispatcher.cs b/MQTTnet.Core/Client/MqttPacketDispatcher.cs index f057b97..6d369b3 100644 --- a/MQTTnet.Core/Client/MqttPacketDispatcher.cs +++ b/MQTTnet.Core/Client/MqttPacketDispatcher.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Threading.Tasks; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; using System.Collections.Concurrent; @@ -22,16 +23,19 @@ namespace MQTTnet.Core.Client var packetAwaiter = AddPacketAwaiter(request, responseType); DispatchPendingPackets(); - var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; - RemovePacketAwaiter(request, responseType); - - if (hasTimeout) + try { - MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); - throw new MqttCommunicationTimedOutException(); + return await packetAwaiter.Task.TimeoutAfter( timeout ); + } + catch ( MqttCommunicationTimedOutException ) + { + MqttTrace.Warning( nameof( MqttPacketDispatcher ), "Timeout while waiting for packet." ); + throw; + } + finally + { + RemovePacketAwaiter(request, responseType); } - - return packetAwaiter.Task.Result; } public void Dispatch(MqttBasePacket packet) diff --git a/MQTTnet.Core/Internal/TaskExtensions.cs b/MQTTnet.Core/Internal/TaskExtensions.cs new file mode 100644 index 0000000..eb8f485 --- /dev/null +++ b/MQTTnet.Core/Internal/TaskExtensions.cs @@ -0,0 +1,55 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Core.Exceptions; + +namespace MQTTnet.Core.Internal +{ + public static class TaskExtensions + { + public static Task TimeoutAfter( this Task task, TimeSpan timeout ) + { + return TimeoutAfter( task.ContinueWith( t => 0 ), timeout ); + } + + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + using (var cancellationTokenSource = new CancellationTokenSource()) + { + var tcs = new TaskCompletionSource(); + + cancellationTokenSource.Token.Register(() => + { + tcs.TrySetCanceled(); + } ); + + try + { + cancellationTokenSource.CancelAfter(timeout); + task.ContinueWith( t => + { + if (t.IsFaulted) + { + tcs.TrySetException(t.Exception); + } + + if (t.IsCompleted) + { + tcs.TrySetResult(t.Result); + } + }, cancellationTokenSource.Token ); + + return await tcs.Task; + } + catch (TaskCanceledException) + { + throw new MqttCommunicationTimedOutException(); + } + catch (Exception e) + { + throw new MqttCommunicationException(e); + } + } + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs new file mode 100644 index 0000000..1a3cd29 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -0,0 +1,33 @@ +using System; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class ExtensionTests + { + [ExpectedException(typeof( MqttCommunicationTimedOutException ) )] + [TestMethod] + public async Task TestTimeoutAfter() + { + await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + } + + [ExpectedException(typeof( MqttCommunicationTimedOutException))] + [TestMethod] + public async Task TestTimeoutAfterWithResult() + { + await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + } + + [TestMethod] + public async Task TestTimeoutAfterCompleteInTime() + { + var result = await Task.Delay( TimeSpan.FromMilliseconds( 100 ) ).ContinueWith( t => 5 ).TimeoutAfter( TimeSpan.FromMilliseconds( 500 ) ); + Assert.AreEqual( 5, result ); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj b/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj index 88bf8f2..2e20398 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj +++ b/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj @@ -86,6 +86,7 @@ +