|
@@ -4,6 +4,7 @@ using System.Threading.Tasks; |
|
|
using MQTTnet.Core.Diagnostics; |
|
|
using MQTTnet.Core.Diagnostics; |
|
|
using MQTTnet.Core.Exceptions; |
|
|
using MQTTnet.Core.Exceptions; |
|
|
using MQTTnet.Core.Packets; |
|
|
using MQTTnet.Core.Packets; |
|
|
|
|
|
using System.Collections.Concurrent; |
|
|
|
|
|
|
|
|
namespace MQTTnet.Core.Client |
|
|
namespace MQTTnet.Core.Client |
|
|
{ |
|
|
{ |
|
@@ -11,17 +12,18 @@ namespace MQTTnet.Core.Client |
|
|
{ |
|
|
{ |
|
|
private readonly object _syncRoot = new object(); |
|
|
private readonly object _syncRoot = new object(); |
|
|
private readonly HashSet<MqttBasePacket> _receivedPackets = new HashSet<MqttBasePacket>(); |
|
|
private readonly HashSet<MqttBasePacket> _receivedPackets = new HashSet<MqttBasePacket>(); |
|
|
private readonly List<MqttPacketAwaiter> _packetAwaiters = new List<MqttPacketAwaiter>(); |
|
|
|
|
|
|
|
|
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>(); |
|
|
|
|
|
private readonly ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>> _packetByIdentifier = new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>(); |
|
|
|
|
|
|
|
|
public async Task<MqttBasePacket> WaitForPacketAsync(Func<MqttBasePacket, bool> selector, TimeSpan timeout) |
|
|
|
|
|
|
|
|
public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) |
|
|
{ |
|
|
{ |
|
|
if (selector == null) throw new ArgumentNullException(nameof(selector)); |
|
|
|
|
|
|
|
|
if (request == null) throw new ArgumentNullException(nameof(request)); |
|
|
|
|
|
|
|
|
var packetAwaiter = AddPacketAwaiter(selector); |
|
|
|
|
|
|
|
|
var packetAwaiter = AddPacketAwaiter(request, responseType); |
|
|
DispatchPendingPackets(); |
|
|
DispatchPendingPackets(); |
|
|
|
|
|
|
|
|
var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; |
|
|
var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; |
|
|
RemovePacketAwaiter(packetAwaiter); |
|
|
|
|
|
|
|
|
RemovePacketAwaiter(request, responseType); |
|
|
|
|
|
|
|
|
if (hasTimeout) |
|
|
if (hasTimeout) |
|
|
{ |
|
|
{ |
|
@@ -37,15 +39,20 @@ namespace MQTTnet.Core.Client |
|
|
if (packet == null) throw new ArgumentNullException(nameof(packet)); |
|
|
if (packet == null) throw new ArgumentNullException(nameof(packet)); |
|
|
|
|
|
|
|
|
var packetDispatched = false; |
|
|
var packetDispatched = false; |
|
|
foreach (var packetAwaiter in GetPacketAwaiters()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (packet is IMqttPacketWithIdentifier withIdentifier) |
|
|
{ |
|
|
{ |
|
|
if (packetAwaiter.PacketSelector(packet)) |
|
|
|
|
|
|
|
|
if (_packetByIdentifier.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) |
|
|
{ |
|
|
{ |
|
|
packetAwaiter.TrySetResult(packet); |
|
|
packetAwaiter.TrySetResult(packet); |
|
|
packetDispatched = true; |
|
|
packetDispatched = true; |
|
|
break; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
else if (_packetByResponseType.TryRemove(packet.GetType(), out var tcs) ) |
|
|
|
|
|
{ |
|
|
|
|
|
tcs.SetResult(packet); |
|
|
|
|
|
packetDispatched = true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
lock (_syncRoot) |
|
|
lock (_syncRoot) |
|
|
{ |
|
|
{ |
|
@@ -64,34 +71,36 @@ namespace MQTTnet.Core.Client |
|
|
{ |
|
|
{ |
|
|
lock (_syncRoot) |
|
|
lock (_syncRoot) |
|
|
{ |
|
|
{ |
|
|
_packetAwaiters.Clear(); |
|
|
|
|
|
_receivedPackets.Clear(); |
|
|
_receivedPackets.Clear(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
_packetByIdentifier.Clear(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
private List<MqttPacketAwaiter> GetPacketAwaiters() |
|
|
|
|
|
|
|
|
private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(MqttBasePacket request, Type responseType) |
|
|
{ |
|
|
{ |
|
|
lock (_syncRoot) |
|
|
|
|
|
|
|
|
var tcs = new TaskCompletionSource<MqttBasePacket>(); |
|
|
|
|
|
if (request is IMqttPacketWithIdentifier withIdent) |
|
|
{ |
|
|
{ |
|
|
return new List<MqttPacketAwaiter>(_packetAwaiters); |
|
|
|
|
|
|
|
|
_packetByIdentifier[withIdent.PacketIdentifier] = tcs; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private MqttPacketAwaiter AddPacketAwaiter(Func<MqttBasePacket, bool> selector) |
|
|
|
|
|
{ |
|
|
|
|
|
lock (_syncRoot) |
|
|
|
|
|
|
|
|
else |
|
|
{ |
|
|
{ |
|
|
var packetAwaiter = new MqttPacketAwaiter(selector); |
|
|
|
|
|
_packetAwaiters.Add(packetAwaiter); |
|
|
|
|
|
return packetAwaiter; |
|
|
|
|
|
|
|
|
_packetByResponseType[responseType] = tcs; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return tcs; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
private void RemovePacketAwaiter(MqttPacketAwaiter packetAwaiter) |
|
|
|
|
|
|
|
|
private void RemovePacketAwaiter(MqttBasePacket request, Type responseType) |
|
|
{ |
|
|
{ |
|
|
lock (_syncRoot) |
|
|
|
|
|
|
|
|
if (request is IMqttPacketWithIdentifier withIdent) |
|
|
|
|
|
{ |
|
|
|
|
|
_packetByIdentifier.TryRemove(withIdent.PacketIdentifier, out var tcs); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
{ |
|
|
{ |
|
|
_packetAwaiters.Remove(packetAwaiter); |
|
|
|
|
|
|
|
|
_packetByResponseType.TryRemove(responseType, out var tcs); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|