@@ -10,8 +10,7 @@ namespace MQTTnet.Client
{
{
public class MqttPacketDispatcher
public class MqttPacketDispatcher
{
{
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>();
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>();
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _awaiters = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>();
private readonly IMqttNetLogger _logger;
private readonly IMqttNetLogger _logger;
public MqttPacketDispatcher(IMqttNetLogger logger)
public MqttPacketDispatcher(IMqttNetLogger logger)
@@ -19,11 +18,9 @@ namespace MQTTnet.Client
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
}
}
public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout)
public async Task<MqttBasePacket> WaitForPacketAsync(Type responseType, ushort identifier , TimeSpan timeout)
{
{
if (request == null) throw new ArgumentNullException(nameof(request));
var packetAwaiter = AddPacketAwaiter(request, responseType);
var packetAwaiter = AddPacketAwaiter(responseType, identifier);
try
try
{
{
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false);
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false);
@@ -35,7 +32,7 @@ namespace MQTTnet.Client
}
}
finally
finally
{
{
RemovePacketAwaiter(request, re sponseType);
RemovePacketAwaiter(responseType, identifier );
}
}
}
}
@@ -44,21 +41,20 @@ namespace MQTTnet.Client
if (packet == null) throw new ArgumentNullException(nameof(packet));
if (packet == null) throw new ArgumentNullException(nameof(packet));
var type = packet.GetType();
var type = packet.GetType();
if (packet is IMqttPacketWithIdentifier withIdentifier)
if (_awaiters.TryGetValue(type, out var byId))
{
{
if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid))
ushort identifier = 0;
if (packet is IMqttPacketWithIdentifier packetWithIdentifier)
{
{
if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs))
{
tcs.TrySetResult(packet);
return;
}
identifier = packetWithIdentifier.PacketIdentifier;
}
if (byId.TryRemove(identifier, out var tcs))
{
tcs.TrySetResult(packet);
return;
}
}
}
else if (_packetByResponseType.TryRemove(type, out var tcs))
{
tcs.TrySetResult(packet);
return;
}
}
throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched.");
throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched.");
@@ -66,38 +62,26 @@ namespace MQTTnet.Client
public void Reset()
public void Reset()
{
{
_packetByResponseTypeAndIdentifier.Clear();
_packetByResponseType.Clear();
_awaiters.Clear();
}
}
private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(MqttBasePacket request, Type responseType)
private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(Type responseType, ushort identifier )
{
{
var tcs = new TaskCompletionSource<MqttBasePacket>();
var tcs = new TaskCompletionSource<MqttBasePacket>();
if (request is IMqttPacketWithIdentifier requestWithIdentifier)
var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
if (!byId.TryAdd(identifier, tcs))
{
{
var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
byId[requestWithIdentifier.PacketIdentifier] = tcs;
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}.");
}
}
else
{
_packetByResponseType[responseType] = tcs;
}
return tcs;
return tcs;
}
}
private void RemovePacketAwaiter(MqttBasePacket request, Type responseType)
private void RemovePacketAwaiter(Type responseType, ushort identifier)
{
{
if (request is IMqttPacketWithIdentifier requestWithIdentifier)
{
var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
byId.TryRemove(requestWithIdentifier.PacketIdentifier, out var _);
}
else
{
_packetByResponseType.TryRemove(responseType, out var _);
}
var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
byId.TryRemove(identifier, out var _);
}
}
}
}
}
}