From 9b8af2455f684a31197abe2ad2bbf0c0afa3fe65 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Tue, 19 Sep 2017 23:08:18 +0200 Subject: [PATCH] Fix issues with "TimeoutAfter" --- MQTTnet.Core/Internal/TaskExtensions.cs | 116 ++++++++++++++++-------- 1 file changed, 78 insertions(+), 38 deletions(-) diff --git a/MQTTnet.Core/Internal/TaskExtensions.cs b/MQTTnet.Core/Internal/TaskExtensions.cs index 82a88ab..38513a0 100644 --- a/MQTTnet.Core/Internal/TaskExtensions.cs +++ b/MQTTnet.Core/Internal/TaskExtensions.cs @@ -7,51 +7,91 @@ namespace MQTTnet.Core.Internal { public static class TaskExtensions { - public static Task TimeoutAfter(this Task task, TimeSpan timeout) + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) { - return TimeoutAfter(task.ContinueWith(t => 0), timeout); + var timeoutTask = Task.Delay(timeout); + var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); + + if (finishedTask == timeoutTask || task.IsCanceled) + { + throw new MqttCommunicationTimedOutException(); + } + + if (task.IsCanceled) + { + throw new TaskCanceledException(); + } + + if (task.IsFaulted) + { + throw new MqttCommunicationException(task.Exception); + } + + ////return TimeoutAfter(task.ContinueWith(t => 0), timeout); } public static async Task TimeoutAfter(this Task task, TimeSpan timeout) { - using (var cancellationTokenSource = new CancellationTokenSource()) + var timeoutTask = Task.Delay(timeout); + var finishedTask = await Task.WhenAny(timeoutTask, task).ConfigureAwait(false); + + if (finishedTask == timeoutTask) { - var tcs = new TaskCompletionSource(); - - cancellationTokenSource.Token.Register(() => - { - tcs.TrySetCanceled(); - }); - - try - { -#pragma warning disable 4014 - task.ContinueWith(t => -#pragma warning restore 4014 - { - if (t.IsFaulted) - { - tcs.TrySetException(t.Exception); - } - - if (t.IsCompleted) - { - tcs.TrySetResult(t.Result); - } - }, cancellationTokenSource.Token); - - cancellationTokenSource.CancelAfter(timeout); - return await tcs.Task; - } - catch (TaskCanceledException) - { - throw new MqttCommunicationTimedOutException(); - } - catch (Exception e) - { - throw new MqttCommunicationException(e); - } + throw new MqttCommunicationTimedOutException(); } + + if (task.IsCanceled) + { + throw new TaskCanceledException(); + } + + if (task.IsFaulted) + { + throw new MqttCommunicationException(task.Exception); + } + + return task.Result; + + //// using (var cancellationTokenSource = new CancellationTokenSource()) + //// { + //// var tcs = new TaskCompletionSource(); + + //// cancellationTokenSource.Token.Register(() => + //// { + //// tcs.TrySetCanceled(); + //// }); + + //// try + //// { + ////#pragma warning disable 4014 + //// task.ContinueWith(t => + ////#pragma warning restore 4014 + //// { + //// if (t.IsFaulted) + //// { + //// tcs.TrySetException(t.Exception); + //// } + + //// if (t.IsCompleted) + //// { + //// tcs.TrySetResult(t.Result); + //// } + + //// return t.Result; + //// }, cancellationTokenSource.Token).ConfigureAwait(false); + + //// cancellationTokenSource.CancelAfter(timeout); + //// return await tcs.Task.ConfigureAwait(false); + //// } + //// catch (TaskCanceledException) + //// { + //// throw new MqttCommunicationTimedOutException(); + //// } + //// catch (Exception e) + //// { + //// throw new MqttCommunicationException(e); + //// } + //// } } } }