@@ -1,10 +1,12 @@ | |||
using System; | |||
using System.IO; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Diagnostics; | |||
using MQTTnet.Core.Exceptions; | |||
using MQTTnet.Core.Internal; | |||
using MQTTnet.Core.Packets; | |||
using MQTTnet.Core.Serializer; | |||
@@ -24,7 +26,7 @@ namespace MQTTnet.Core.Adapter | |||
public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) | |||
{ | |||
return ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout); | |||
return _channel.ConnectAsync(options).TimeoutAfter(timeout); | |||
} | |||
public Task DisconnectAsync() | |||
@@ -38,7 +40,7 @@ namespace MQTTnet.Core.Adapter | |||
var writeBuffer = PacketSerializer.Serialize(packet); | |||
_sendTask = SendAsync( writeBuffer ); | |||
return ExecuteWithTimeoutAsync(_sendTask, timeout); | |||
return _sendTask.TimeoutAfter(timeout); | |||
} | |||
private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write | |||
@@ -54,7 +56,7 @@ namespace MQTTnet.Core.Adapter | |||
Tuple<MqttPacketHeader, MemoryStream> tuple; | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
tuple = await ExecuteWithTimeoutAsync(ReceiveAsync(), timeout).ConfigureAwait(false); | |||
tuple = await ReceiveAsync().TimeoutAfter(timeout).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
@@ -96,35 +98,5 @@ namespace MQTTnet.Core.Adapter | |||
return Tuple.Create(header, body); | |||
} | |||
private static async Task<TResult> ExecuteWithTimeoutAsync<TResult>(Task<TResult> task, TimeSpan timeout) | |||
{ | |||
var timeoutTask = Task.Delay(timeout); | |||
if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) | |||
{ | |||
throw new MqttCommunicationTimedOutException(); | |||
} | |||
if (task.IsFaulted) | |||
{ | |||
throw new MqttCommunicationException(task.Exception); | |||
} | |||
return task.Result; | |||
} | |||
private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) | |||
{ | |||
var timeoutTask = Task.Delay(timeout); | |||
if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) | |||
{ | |||
throw new MqttCommunicationTimedOutException(); | |||
} | |||
if (task.IsFaulted) | |||
{ | |||
throw new MqttCommunicationException(task.Exception); | |||
} | |||
} | |||
} | |||
} |
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Core.Diagnostics; | |||
using MQTTnet.Core.Exceptions; | |||
using MQTTnet.Core.Internal; | |||
using MQTTnet.Core.Packets; | |||
using System.Collections.Concurrent; | |||
@@ -22,16 +23,19 @@ namespace MQTTnet.Core.Client | |||
var packetAwaiter = AddPacketAwaiter(request, responseType); | |||
DispatchPendingPackets(); | |||
var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; | |||
RemovePacketAwaiter(request, responseType); | |||
if (hasTimeout) | |||
try | |||
{ | |||
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); | |||
throw new MqttCommunicationTimedOutException(); | |||
return await packetAwaiter.Task.TimeoutAfter( timeout ); | |||
} | |||
catch ( MqttCommunicationTimedOutException ) | |||
{ | |||
MqttTrace.Warning( nameof( MqttPacketDispatcher ), "Timeout while waiting for packet." ); | |||
throw; | |||
} | |||
finally | |||
{ | |||
RemovePacketAwaiter(request, responseType); | |||
} | |||
return packetAwaiter.Task.Result; | |||
} | |||
public void Dispatch(MqttBasePacket packet) | |||
@@ -0,0 +1,55 @@ | |||
using System; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Core.Exceptions; | |||
namespace MQTTnet.Core.Internal | |||
{ | |||
public static class TaskExtensions | |||
{ | |||
public static Task TimeoutAfter( this Task task, TimeSpan timeout ) | |||
{ | |||
return TimeoutAfter( task.ContinueWith( t => 0 ), timeout ); | |||
} | |||
public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout) | |||
{ | |||
using (var cancellationTokenSource = new CancellationTokenSource()) | |||
{ | |||
var tcs = new TaskCompletionSource<TResult>(); | |||
cancellationTokenSource.Token.Register(() => | |||
{ | |||
tcs.TrySetCanceled(); | |||
} ); | |||
try | |||
{ | |||
cancellationTokenSource.CancelAfter(timeout); | |||
task.ContinueWith( t => | |||
{ | |||
if (t.IsFaulted) | |||
{ | |||
tcs.TrySetException(t.Exception); | |||
} | |||
if (t.IsCompleted) | |||
{ | |||
tcs.TrySetResult(t.Result); | |||
} | |||
}, cancellationTokenSource.Token ); | |||
return await tcs.Task; | |||
} | |||
catch (TaskCanceledException) | |||
{ | |||
throw new MqttCommunicationTimedOutException(); | |||
} | |||
catch (Exception e) | |||
{ | |||
throw new MqttCommunicationException(e); | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,33 @@ | |||
using System; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Core.Exceptions; | |||
using MQTTnet.Core.Internal; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
[TestClass] | |||
public class ExtensionTests | |||
{ | |||
[ExpectedException(typeof( MqttCommunicationTimedOutException ) )] | |||
[TestMethod] | |||
public async Task TestTimeoutAfter() | |||
{ | |||
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||
} | |||
[ExpectedException(typeof( MqttCommunicationTimedOutException))] | |||
[TestMethod] | |||
public async Task TestTimeoutAfterWithResult() | |||
{ | |||
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||
} | |||
[TestMethod] | |||
public async Task TestTimeoutAfterCompleteInTime() | |||
{ | |||
var result = await Task.Delay( TimeSpan.FromMilliseconds( 100 ) ).ContinueWith( t => 5 ).TimeoutAfter( TimeSpan.FromMilliseconds( 500 ) ); | |||
Assert.AreEqual( 5, result ); | |||
} | |||
} | |||
} |
@@ -86,6 +86,7 @@ | |||
<ItemGroup> | |||
<Compile Include="ByteReaderTests.cs" /> | |||
<Compile Include="ByteWriterTests.cs" /> | |||
<Compile Include="ExtensionTests.cs" /> | |||
<Compile Include="MqttPacketSerializerTests.cs" /> | |||
<Compile Include="MqttServerTests.cs" /> | |||
<Compile Include="MqttSubscriptionsManagerTests.cs" /> | |||