Browse Source

Optimize dispatcher

release/3.x.x
Christian 6 years ago
parent
commit
90098c2006
2 changed files with 31 additions and 41 deletions
  1. +7
    -1
      Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs
  2. +24
    -40
      Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs

+ 7
- 1
Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs View File

@@ -335,7 +335,13 @@ namespace MQTTnet.Client


private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket
{ {
var packetAwaiter = _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.CommunicationTimeout);
ushort identifier = 0;
if (requestPacket is IMqttPacketWithIdentifier requestPacketWithIdentifier)
{
identifier = requestPacketWithIdentifier.PacketIdentifier;
}

var packetAwaiter = _packetDispatcher.WaitForPacketAsync(typeof(TResponsePacket), identifier, _options.CommunicationTimeout);
await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false); await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false);
return (TResponsePacket)await packetAwaiter.ConfigureAwait(false); return (TResponsePacket)await packetAwaiter.ConfigureAwait(false);
} }


+ 24
- 40
Frameworks/MQTTnet.NetStandard/Client/MqttPacketDispatcher.cs View File

@@ -10,8 +10,7 @@ namespace MQTTnet.Client
{ {
public class MqttPacketDispatcher public class MqttPacketDispatcher
{ {
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>();
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>();
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _awaiters = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>();
private readonly IMqttNetLogger _logger; private readonly IMqttNetLogger _logger;


public MqttPacketDispatcher(IMqttNetLogger logger) public MqttPacketDispatcher(IMqttNetLogger logger)
@@ -19,11 +18,9 @@ namespace MQTTnet.Client
_logger = logger ?? throw new ArgumentNullException(nameof(logger)); _logger = logger ?? throw new ArgumentNullException(nameof(logger));
} }


public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout)
public async Task<MqttBasePacket> WaitForPacketAsync(Type responseType, ushort identifier, TimeSpan timeout)
{ {
if (request == null) throw new ArgumentNullException(nameof(request));

var packetAwaiter = AddPacketAwaiter(request, responseType);
var packetAwaiter = AddPacketAwaiter(responseType, identifier);
try try
{ {
return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false); return await packetAwaiter.Task.TimeoutAfter(timeout).ConfigureAwait(false);
@@ -35,7 +32,7 @@ namespace MQTTnet.Client
} }
finally finally
{ {
RemovePacketAwaiter(request, responseType);
RemovePacketAwaiter(responseType, identifier);
} }
} }


@@ -44,21 +41,20 @@ namespace MQTTnet.Client
if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packet == null) throw new ArgumentNullException(nameof(packet));


var type = packet.GetType(); var type = packet.GetType();
if (packet is IMqttPacketWithIdentifier withIdentifier)

if (_awaiters.TryGetValue(type, out var byId))
{ {
if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid))
ushort identifier = 0;
if (packet is IMqttPacketWithIdentifier packetWithIdentifier)
{ {
if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs))
{
tcs.TrySetResult(packet);
return;
}
identifier = packetWithIdentifier.PacketIdentifier;
}

if (byId.TryRemove(identifier, out var tcs))
{
tcs.TrySetResult(packet);
return;
} }
}
else if (_packetByResponseType.TryRemove(type, out var tcs))
{
tcs.TrySetResult(packet);
return;
} }


throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched.");
@@ -66,38 +62,26 @@ namespace MQTTnet.Client


public void Reset() public void Reset()
{ {
_packetByResponseTypeAndIdentifier.Clear();
_packetByResponseType.Clear();
_awaiters.Clear();
} }


private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(MqttBasePacket request, Type responseType)
private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(Type responseType, ushort identifier)
{ {
var tcs = new TaskCompletionSource<MqttBasePacket>(); var tcs = new TaskCompletionSource<MqttBasePacket>();


if (request is IMqttPacketWithIdentifier requestWithIdentifier)
var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
if (!byId.TryAdd(identifier, tcs))
{ {
var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
byId[requestWithIdentifier.PacketIdentifier] = tcs;
throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{responseType}' with identifier {identifier}.");
} }
else
{
_packetByResponseType[responseType] = tcs;
}

return tcs; return tcs;
} }


private void RemovePacketAwaiter(MqttBasePacket request, Type responseType)
private void RemovePacketAwaiter(Type responseType, ushort identifier)
{ {
if (request is IMqttPacketWithIdentifier requestWithIdentifier)
{
var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
byId.TryRemove(requestWithIdentifier.PacketIdentifier, out var _);
}
else
{
_packetByResponseType.TryRemove(responseType, out var _);
}
var byId = _awaiters.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
byId.TryRemove(identifier, out var _);
} }
} }
} }

Loading…
Cancel
Save