@@ -1,10 +1,12 @@ | |||||
using System; | using System; | ||||
using System.IO; | using System.IO; | ||||
using System.Threading; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Diagnostics; | using MQTTnet.Core.Diagnostics; | ||||
using MQTTnet.Core.Exceptions; | using MQTTnet.Core.Exceptions; | ||||
using MQTTnet.Core.Internal; | |||||
using MQTTnet.Core.Packets; | using MQTTnet.Core.Packets; | ||||
using MQTTnet.Core.Serializer; | using MQTTnet.Core.Serializer; | ||||
@@ -24,7 +26,7 @@ namespace MQTTnet.Core.Adapter | |||||
public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) | public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) | ||||
{ | { | ||||
return ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout); | |||||
return _channel.ConnectAsync(options).TimeoutAfter(timeout); | |||||
} | } | ||||
public Task DisconnectAsync() | public Task DisconnectAsync() | ||||
@@ -38,7 +40,7 @@ namespace MQTTnet.Core.Adapter | |||||
var writeBuffer = PacketSerializer.Serialize(packet); | var writeBuffer = PacketSerializer.Serialize(packet); | ||||
_sendTask = SendAsync( writeBuffer ); | _sendTask = SendAsync( writeBuffer ); | ||||
return ExecuteWithTimeoutAsync(_sendTask, timeout); | |||||
return _sendTask.TimeoutAfter(timeout); | |||||
} | } | ||||
private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write | private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write | ||||
@@ -54,7 +56,7 @@ namespace MQTTnet.Core.Adapter | |||||
Tuple<MqttPacketHeader, MemoryStream> tuple; | Tuple<MqttPacketHeader, MemoryStream> tuple; | ||||
if (timeout > TimeSpan.Zero) | if (timeout > TimeSpan.Zero) | ||||
{ | { | ||||
tuple = await ExecuteWithTimeoutAsync(ReceiveAsync(), timeout).ConfigureAwait(false); | |||||
tuple = await ReceiveAsync().TimeoutAfter(timeout).ConfigureAwait(false); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -96,35 +98,5 @@ namespace MQTTnet.Core.Adapter | |||||
return Tuple.Create(header, body); | return Tuple.Create(header, body); | ||||
} | } | ||||
private static async Task<TResult> ExecuteWithTimeoutAsync<TResult>(Task<TResult> task, TimeSpan timeout) | |||||
{ | |||||
var timeoutTask = Task.Delay(timeout); | |||||
if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) | |||||
{ | |||||
throw new MqttCommunicationTimedOutException(); | |||||
} | |||||
if (task.IsFaulted) | |||||
{ | |||||
throw new MqttCommunicationException(task.Exception); | |||||
} | |||||
return task.Result; | |||||
} | |||||
private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) | |||||
{ | |||||
var timeoutTask = Task.Delay(timeout); | |||||
if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) | |||||
{ | |||||
throw new MqttCommunicationTimedOutException(); | |||||
} | |||||
if (task.IsFaulted) | |||||
{ | |||||
throw new MqttCommunicationException(task.Exception); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Core.Diagnostics; | using MQTTnet.Core.Diagnostics; | ||||
using MQTTnet.Core.Exceptions; | using MQTTnet.Core.Exceptions; | ||||
using MQTTnet.Core.Internal; | |||||
using MQTTnet.Core.Packets; | using MQTTnet.Core.Packets; | ||||
using System.Collections.Concurrent; | using System.Collections.Concurrent; | ||||
@@ -22,16 +23,19 @@ namespace MQTTnet.Core.Client | |||||
var packetAwaiter = AddPacketAwaiter(request, responseType); | var packetAwaiter = AddPacketAwaiter(request, responseType); | ||||
DispatchPendingPackets(); | DispatchPendingPackets(); | ||||
var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; | |||||
RemovePacketAwaiter(request, responseType); | |||||
if (hasTimeout) | |||||
try | |||||
{ | { | ||||
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); | |||||
throw new MqttCommunicationTimedOutException(); | |||||
return await packetAwaiter.Task.TimeoutAfter( timeout ); | |||||
} | |||||
catch ( MqttCommunicationTimedOutException ) | |||||
{ | |||||
MqttTrace.Warning( nameof( MqttPacketDispatcher ), "Timeout while waiting for packet." ); | |||||
throw; | |||||
} | |||||
finally | |||||
{ | |||||
RemovePacketAwaiter(request, responseType); | |||||
} | } | ||||
return packetAwaiter.Task.Result; | |||||
} | } | ||||
public void Dispatch(MqttBasePacket packet) | public void Dispatch(MqttBasePacket packet) | ||||
@@ -0,0 +1,55 @@ | |||||
using System; | |||||
using System.Threading; | |||||
using System.Threading.Tasks; | |||||
using MQTTnet.Core.Exceptions; | |||||
namespace MQTTnet.Core.Internal | |||||
{ | |||||
public static class TaskExtensions | |||||
{ | |||||
public static Task TimeoutAfter( this Task task, TimeSpan timeout ) | |||||
{ | |||||
return TimeoutAfter( task.ContinueWith( t => 0 ), timeout ); | |||||
} | |||||
public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout) | |||||
{ | |||||
using (var cancellationTokenSource = new CancellationTokenSource()) | |||||
{ | |||||
var tcs = new TaskCompletionSource<TResult>(); | |||||
cancellationTokenSource.Token.Register(() => | |||||
{ | |||||
tcs.TrySetCanceled(); | |||||
} ); | |||||
try | |||||
{ | |||||
cancellationTokenSource.CancelAfter(timeout); | |||||
task.ContinueWith( t => | |||||
{ | |||||
if (t.IsFaulted) | |||||
{ | |||||
tcs.TrySetException(t.Exception); | |||||
} | |||||
if (t.IsCompleted) | |||||
{ | |||||
tcs.TrySetResult(t.Result); | |||||
} | |||||
}, cancellationTokenSource.Token ); | |||||
return await tcs.Task; | |||||
} | |||||
catch (TaskCanceledException) | |||||
{ | |||||
throw new MqttCommunicationTimedOutException(); | |||||
} | |||||
catch (Exception e) | |||||
{ | |||||
throw new MqttCommunicationException(e); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,33 @@ | |||||
using System; | |||||
using System.Threading.Tasks; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using MQTTnet.Core.Exceptions; | |||||
using MQTTnet.Core.Internal; | |||||
namespace MQTTnet.Core.Tests | |||||
{ | |||||
[TestClass] | |||||
public class ExtensionTests | |||||
{ | |||||
[ExpectedException(typeof( MqttCommunicationTimedOutException ) )] | |||||
[TestMethod] | |||||
public async Task TestTimeoutAfter() | |||||
{ | |||||
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||||
} | |||||
[ExpectedException(typeof( MqttCommunicationTimedOutException))] | |||||
[TestMethod] | |||||
public async Task TestTimeoutAfterWithResult() | |||||
{ | |||||
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); | |||||
} | |||||
[TestMethod] | |||||
public async Task TestTimeoutAfterCompleteInTime() | |||||
{ | |||||
var result = await Task.Delay( TimeSpan.FromMilliseconds( 100 ) ).ContinueWith( t => 5 ).TimeoutAfter( TimeSpan.FromMilliseconds( 500 ) ); | |||||
Assert.AreEqual( 5, result ); | |||||
} | |||||
} | |||||
} |
@@ -86,6 +86,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<Compile Include="ByteReaderTests.cs" /> | <Compile Include="ByteReaderTests.cs" /> | ||||
<Compile Include="ByteWriterTests.cs" /> | <Compile Include="ByteWriterTests.cs" /> | ||||
<Compile Include="ExtensionTests.cs" /> | |||||
<Compile Include="MqttPacketSerializerTests.cs" /> | <Compile Include="MqttPacketSerializerTests.cs" /> | ||||
<Compile Include="MqttServerTests.cs" /> | <Compile Include="MqttServerTests.cs" /> | ||||
<Compile Include="MqttSubscriptionsManagerTests.cs" /> | <Compile Include="MqttSubscriptionsManagerTests.cs" /> | ||||