|
|
@@ -14,7 +14,7 @@ namespace MQTTnet.Core.Client |
|
|
|
private readonly object _syncRoot = new object(); |
|
|
|
private readonly HashSet<MqttBasePacket> _receivedPackets = new HashSet<MqttBasePacket>(); |
|
|
|
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>(); |
|
|
|
private readonly ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>> _packetByIdentifier = new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>(); |
|
|
|
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>>(); |
|
|
|
|
|
|
|
public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) |
|
|
|
{ |
|
|
@@ -42,17 +42,21 @@ namespace MQTTnet.Core.Client |
|
|
|
{ |
|
|
|
if (packet == null) throw new ArgumentNullException(nameof(packet)); |
|
|
|
|
|
|
|
var type = packet.GetType(); |
|
|
|
var packetDispatched = false; |
|
|
|
|
|
|
|
if (packet is IMqttPacketWithIdentifier withIdentifier) |
|
|
|
{ |
|
|
|
if (_packetByIdentifier.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) |
|
|
|
if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) |
|
|
|
{ |
|
|
|
tcs.TrySetResult(packet); |
|
|
|
packetDispatched = true; |
|
|
|
if (byid.TryRemove( withIdentifier.PacketIdentifier, out var tcs)) |
|
|
|
{ |
|
|
|
tcs.TrySetResult( packet ); |
|
|
|
packetDispatched = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
else if (_packetByResponseType.TryRemove(packet.GetType(), out var tcs)) |
|
|
|
else if (_packetByResponseType.TryRemove(type, out var tcs)) |
|
|
|
{ |
|
|
|
tcs.TrySetResult(packet); |
|
|
|
packetDispatched = true; |
|
|
@@ -78,7 +82,7 @@ namespace MQTTnet.Core.Client |
|
|
|
_receivedPackets.Clear(); |
|
|
|
} |
|
|
|
|
|
|
|
_packetByIdentifier.Clear(); |
|
|
|
_packetByResponseTypeAndIdentifier.Clear(); |
|
|
|
} |
|
|
|
|
|
|
|
private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(MqttBasePacket request, Type responseType) |
|
|
@@ -86,7 +90,8 @@ namespace MQTTnet.Core.Client |
|
|
|
var tcs = new TaskCompletionSource<MqttBasePacket>(); |
|
|
|
if (request is IMqttPacketWithIdentifier withIdent) |
|
|
|
{ |
|
|
|
_packetByIdentifier[withIdent.PacketIdentifier] = tcs; |
|
|
|
var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>()); |
|
|
|
byId[withIdent.PacketIdentifier] = tcs; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
@@ -100,7 +105,8 @@ namespace MQTTnet.Core.Client |
|
|
|
{ |
|
|
|
if (request is IMqttPacketWithIdentifier withIdent) |
|
|
|
{ |
|
|
|
_packetByIdentifier.TryRemove(withIdent.PacketIdentifier, out var _); |
|
|
|
var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>()); |
|
|
|
byId.TryRemove(withIdent.PacketIdentifier, out var _); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|