diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs index f71ada1..ffdc7f2 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs @@ -335,7 +335,13 @@ namespace MQTTnet.Client private async Task SendAndReceiveAsync(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket { - var packetAwaiter = _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.CommunicationTimeout); + ushort identifier = 0; + if (requestPacket is IMqttPacketWithIdentifier requestPacketWithIdentifier) + { + identifier = requestPacketWithIdentifier.PacketIdentifier; + } + + var packetAwaiter = _packetDispatcher.WaitForPacketAsync(typeof(TResponsePacket), identifier, _options.CommunicationTimeout); await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false); return (TResponsePacket)await packetAwaiter.ConfigureAwait(false); } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs index 2404bb6..ca75931 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs @@ -10,8 +10,7 @@ namespace MQTTnet.Client { public class MqttPacketDispatcher { - private readonly ConcurrentDictionary> _packetByResponseType = new ConcurrentDictionary>(); - private readonly ConcurrentDictionary>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary>>(); + private readonly ConcurrentDictionary>> _awaiters = new ConcurrentDictionary>>(); private readonly IMqttNetLogger _logger; public MqttPacketDispatcher(IMqttNetLogger logger) @@ -19,11 +18,9 @@ namespace MQTTnet.Client _logger = logger ?? throw new ArgumentNullException(nameof(logger)); } - public async Task WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) + public async Task WaitForPacketAsync(Type responseType, ushort identifier, TimeSpan timeout) { - if (request == null) throw new ArgumentNullException(nameof(request)); - - var packetAwaiter = AddPacketAwaiter(request, responseType); + var packetAwaiter = AddPacketAwaiter(responseType, identifier); try { return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false); @@ -35,7 +32,7 @@ namespace MQTTnet.Client } finally { - RemovePacketAwaiter(request, responseType); + RemovePacketAwaiter(responseType, identifier); } } @@ -44,21 +41,20 @@ namespace MQTTnet.Client if (packet == null) throw new ArgumentNullException(nameof(packet)); var type = packet.GetType(); - if (packet is IMqttPacketWithIdentifier withIdentifier) + + if (_awaiters.TryGetValue(type, out var byId)) { - if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) + ushort identifier = 0; + if (packet is IMqttPacketWithIdentifier packetWithIdentifier) { - if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) - { - tcs.TrySetResult(packet); - return; - } + identifier = packetWithIdentifier.PacketIdentifier; + } + + if (byId.TryRemove(identifier, out var tcs)) + { + tcs.TrySetResult(packet); + return; } - } - else if (_packetByResponseType.TryRemove(type, out var tcs)) - { - tcs.TrySetResult(packet); - return; } throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); @@ -66,38 +62,26 @@ namespace MQTTnet.Client public void Reset() { - _packetByResponseTypeAndIdentifier.Clear(); - _packetByResponseType.Clear(); + _awaiters.Clear(); } - private TaskCompletionSource AddPacketAwaiter(MqttBasePacket request, Type responseType) + private TaskCompletionSource AddPacketAwaiter(Type responseType, ushort identifier) { var tcs = new TaskCompletionSource(); - if (request is IMqttPacketWithIdentifier requestWithIdentifier) + var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary>()); + if (!byId.TryAdd(identifier, tcs)) { - var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary>()); - byId[requestWithIdentifier.PacketIdentifier] = tcs; + throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}."); } - else - { - _packetByResponseType[responseType] = tcs; - } - + return tcs; } - private void RemovePacketAwaiter(MqttBasePacket request, Type responseType) + private void RemovePacketAwaiter(Type responseType, ushort identifier) { - if (request is IMqttPacketWithIdentifier requestWithIdentifier) - { - var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary>()); - byId.TryRemove(requestWithIdentifier.PacketIdentifier, out var _); - } - else - { - _packetByResponseType.TryRemove(responseType, out var _); - } + var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary>()); + byId.TryRemove(identifier, out var _); } } } \ No newline at end of file