diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index 29b44cf..0480ffe 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -321,23 +321,9 @@ namespace MQTTnet.Core.Client private async Task SendAndReceiveAsync(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket { - bool ResponsePacketSelector(MqttBasePacket p) - { - if (!(p is TResponsePacket p1)) - { - return false; - } - - if (!(requestPacket is IMqttPacketWithIdentifier pi1) || !(p is IMqttPacketWithIdentifier pi2)) - { - return true; - } - - return pi1.PacketIdentifier == pi2.PacketIdentifier; - } - await _adapter.SendPacketAsync(requestPacket, _options.DefaultCommunicationTimeout).ConfigureAwait(false); - return (TResponsePacket)await _packetDispatcher.WaitForPacketAsync(ResponsePacketSelector, _options.DefaultCommunicationTimeout).ConfigureAwait(false); + + return (TResponsePacket)await _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.DefaultCommunicationTimeout).ConfigureAwait(false); } private ushort GetNewPacketIdentifier() diff --git a/MQTTnet.Core/Client/MqttPacketAwaiter.cs b/MQTTnet.Core/Client/MqttPacketAwaiter.cs deleted file mode 100644 index b7a6555..0000000 --- a/MQTTnet.Core/Client/MqttPacketAwaiter.cs +++ /dev/null @@ -1,16 +0,0 @@ -using System; -using System.Threading.Tasks; -using MQTTnet.Core.Packets; - -namespace MQTTnet.Core.Client -{ - public class MqttPacketAwaiter : TaskCompletionSource - { - public MqttPacketAwaiter(Func packetSelector) - { - PacketSelector = packetSelector ?? throw new ArgumentNullException(nameof(packetSelector)); - } - - public Func PacketSelector { get; } - } -} \ No newline at end of file diff --git a/MQTTnet.Core/Client/MqttPacketDispatcher.cs b/MQTTnet.Core/Client/MqttPacketDispatcher.cs index aba050d..f057b97 100644 --- a/MQTTnet.Core/Client/MqttPacketDispatcher.cs +++ b/MQTTnet.Core/Client/MqttPacketDispatcher.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; using MQTTnet.Core.Packets; +using System.Collections.Concurrent; namespace MQTTnet.Core.Client { @@ -11,17 +12,18 @@ namespace MQTTnet.Core.Client { private readonly object _syncRoot = new object(); private readonly HashSet _receivedPackets = new HashSet(); - private readonly List _packetAwaiters = new List(); + private readonly ConcurrentDictionary> _packetByResponseType = new ConcurrentDictionary>(); + private readonly ConcurrentDictionary> _packetByIdentifier = new ConcurrentDictionary>(); - public async Task WaitForPacketAsync(Func selector, TimeSpan timeout) + public async Task WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) { - if (selector == null) throw new ArgumentNullException(nameof(selector)); + if (request == null) throw new ArgumentNullException(nameof(request)); - var packetAwaiter = AddPacketAwaiter(selector); + var packetAwaiter = AddPacketAwaiter(request, responseType); DispatchPendingPackets(); var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; - RemovePacketAwaiter(packetAwaiter); + RemovePacketAwaiter(request, responseType); if (hasTimeout) { @@ -37,15 +39,20 @@ namespace MQTTnet.Core.Client if (packet == null) throw new ArgumentNullException(nameof(packet)); var packetDispatched = false; - foreach (var packetAwaiter in GetPacketAwaiters()) + + if (packet is IMqttPacketWithIdentifier withIdentifier) { - if (packetAwaiter.PacketSelector(packet)) + if (_packetByIdentifier.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) { - packetAwaiter.TrySetResult(packet); + tcs.TrySetResult(packet); packetDispatched = true; - break; } } + else if (_packetByResponseType.TryRemove(packet.GetType(), out var tcs) ) + { + tcs.TrySetResult( packet); + packetDispatched = true; + } lock (_syncRoot) { @@ -64,34 +71,36 @@ namespace MQTTnet.Core.Client { lock (_syncRoot) { - _packetAwaiters.Clear(); _receivedPackets.Clear(); } + + _packetByIdentifier.Clear(); } - private List GetPacketAwaiters() + private TaskCompletionSource AddPacketAwaiter(MqttBasePacket request, Type responseType) { - lock (_syncRoot) + var tcs = new TaskCompletionSource(); + if (request is IMqttPacketWithIdentifier withIdent) { - return new List(_packetAwaiters); + _packetByIdentifier[withIdent.PacketIdentifier] = tcs; } - } - - private MqttPacketAwaiter AddPacketAwaiter(Func selector) - { - lock (_syncRoot) + else { - var packetAwaiter = new MqttPacketAwaiter(selector); - _packetAwaiters.Add(packetAwaiter); - return packetAwaiter; + _packetByResponseType[responseType] = tcs; } + + return tcs; } - private void RemovePacketAwaiter(MqttPacketAwaiter packetAwaiter) + private void RemovePacketAwaiter(MqttBasePacket request, Type responseType) { - lock (_syncRoot) + if (request is IMqttPacketWithIdentifier withIdent) + { + _packetByIdentifier.TryRemove(withIdent.PacketIdentifier, out var tcs); + } + else { - _packetAwaiters.Remove(packetAwaiter); + _packetByResponseType.TryRemove(responseType, out var tcs); } }