From b5b87f5272559420fa1c3bf494137a774f3e2b12 Mon Sep 17 00:00:00 2001 From: Eggers Jan Date: Wed, 20 Sep 2017 09:33:41 +0200 Subject: [PATCH] improved timeout after memory usage --- MQTTnet.Core/Internal/TaskExtensions.cs | 121 +++++++++------------ Tests/MQTTnet.Core.Tests/ExtensionTests.cs | 19 ++++ 2 files changed, 68 insertions(+), 72 deletions(-) diff --git a/MQTTnet.Core/Internal/TaskExtensions.cs b/MQTTnet.Core/Internal/TaskExtensions.cs index f9428ab..6db8067 100644 --- a/MQTTnet.Core/Internal/TaskExtensions.cs +++ b/MQTTnet.Core/Internal/TaskExtensions.cs @@ -7,91 +7,68 @@ namespace MQTTnet.Core.Internal { public static class TaskExtensions { - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter( this Task task, TimeSpan timeout ) { - var timeoutTask = Task.Delay(timeout); - var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - - if (finishedTask == timeoutTask) + using ( var cancellationTokenSource = new CancellationTokenSource() ) { - throw new MqttCommunicationTimedOutException(); - } + try + { + var timeoutTask = Task.Delay(timeout, cancellationTokenSource.Token); + var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - if (task.IsCanceled) - { - throw new TaskCanceledException(); - } + if ( finishedTask == timeoutTask ) + { + throw new MqttCommunicationTimedOutException(); + } - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception.GetBaseException()); - } + if ( task.IsCanceled ) + { + throw new TaskCanceledException(); + } - ////return TimeoutAfter(task.ContinueWith(t => 0), timeout); + if ( task.IsFaulted ) + { + throw new MqttCommunicationException( task.Exception.GetBaseException() ); + } + } + finally + { + cancellationTokenSource.Cancel(); + } + } } - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter( this Task task, TimeSpan timeout ) { - var timeoutTask = Task.Delay(timeout); - var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - - if (finishedTask == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } - - if (task.IsCanceled) + using ( var cancellationTokenSource = new CancellationTokenSource() ) { - throw new TaskCanceledException(); - } - - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception.GetBaseException()); - } + try + { + var timeoutTask = Task.Delay(timeout, cancellationTokenSource.Token); + var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); - return task.Result; + if ( finishedTask == timeoutTask ) + { + throw new MqttCommunicationTimedOutException(); + } - //// using (var cancellationTokenSource = new CancellationTokenSource()) - //// { - //// var tcs = new TaskCompletionSource(); + if ( task.IsCanceled ) + { + throw new TaskCanceledException(); + } - //// cancellationTokenSource.Token.Register(() => - //// { - //// tcs.TrySetCanceled(); - //// }); + if ( task.IsFaulted ) + { + throw new MqttCommunicationException( task.Exception.GetBaseException() ); + } - //// try - //// { - ////#pragma warning disable 4014 - //// task.ContinueWith(t => - ////#pragma warning restore 4014 - //// { - //// if (t.IsFaulted) - //// { - //// tcs.TrySetException(t.Exception); - //// } - - //// if (t.IsCompleted) - //// { - //// tcs.TrySetResult(t.Result); - //// } - - //// return t.Result; - //// }, cancellationTokenSource.Token).ConfigureAwait(false); - - //// cancellationTokenSource.CancelAfter(timeout); - //// return await tcs.Task.ConfigureAwait(false); - //// } - //// catch (TaskCanceledException) - //// { - //// throw new MqttCommunicationTimedOutException(); - //// } - //// catch (Exception e) - //// { - //// throw new MqttCommunicationException(e); - //// } - //// } + return task.Result; + } + finally + { + cancellationTokenSource.Cancel(); + } + } } } } diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs index b827326..e0d2152 100644 --- a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Core.Exceptions; @@ -67,5 +68,23 @@ namespace MQTTnet.Core.Tests Assert.IsTrue(e.InnerException is IndexOutOfRangeException); } } + + [TestMethod] + public async Task TimeoutAfterMemoryUsage() + { + var tasks = Enumerable.Range(0, 100000) + .Select(i => Task.Delay(TimeSpan.FromMilliseconds(1)).TimeoutAfter(TimeSpan.FromMinutes(1))); + + await Task.WhenAll( tasks ); + AssertIsLess( 3_000_000, GC.GetTotalMemory( true ) ); + } + + private void AssertIsLess( long bound, long actual ) + { + if ( bound < actual ) + { + Assert.Fail( $"value must be less than {bound:N0} but is {actual:N0}" ); + } + } } }