diff --git a/MQTTnet.Core/Client/MqttPacketDispatcher.cs b/MQTTnet.Core/Client/MqttPacketDispatcher.cs index 1551d3f..a825581 100644 --- a/MQTTnet.Core/Client/MqttPacketDispatcher.cs +++ b/MQTTnet.Core/Client/MqttPacketDispatcher.cs @@ -14,7 +14,7 @@ namespace MQTTnet.Core.Client private readonly object _syncRoot = new object(); private readonly HashSet _receivedPackets = new HashSet(); private readonly ConcurrentDictionary> _packetByResponseType = new ConcurrentDictionary>(); - private readonly ConcurrentDictionary> _packetByIdentifier = new ConcurrentDictionary>(); + private readonly ConcurrentDictionary>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary>>(); public async Task WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) { @@ -42,17 +42,21 @@ namespace MQTTnet.Core.Client { if (packet == null) throw new ArgumentNullException(nameof(packet)); + var type = packet.GetType(); var packetDispatched = false; if (packet is IMqttPacketWithIdentifier withIdentifier) { - if (_packetByIdentifier.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) + if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) { - tcs.TrySetResult(packet); - packetDispatched = true; + if (byid.TryRemove( withIdentifier.PacketIdentifier, out var tcs)) + { + tcs.TrySetResult( packet ); + packetDispatched = true; + } } } - else if (_packetByResponseType.TryRemove(packet.GetType(), out var tcs)) + else if (_packetByResponseType.TryRemove(type, out var tcs)) { tcs.TrySetResult(packet); packetDispatched = true; @@ -78,7 +82,7 @@ namespace MQTTnet.Core.Client _receivedPackets.Clear(); } - _packetByIdentifier.Clear(); + _packetByResponseTypeAndIdentifier.Clear(); } private TaskCompletionSource AddPacketAwaiter(MqttBasePacket request, Type responseType) @@ -86,7 +90,8 @@ namespace MQTTnet.Core.Client var tcs = new TaskCompletionSource(); if (request is IMqttPacketWithIdentifier withIdent) { - _packetByIdentifier[withIdent.PacketIdentifier] = tcs; + var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary>()); + byId[withIdent.PacketIdentifier] = tcs; } else { @@ -100,7 +105,8 @@ namespace MQTTnet.Core.Client { if (request is IMqttPacketWithIdentifier withIdent) { - _packetByIdentifier.TryRemove(withIdent.PacketIdentifier, out var _); + var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary>()); + byId.TryRemove(withIdent.PacketIdentifier, out var _); } else {