You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

96 lines
3.7 KiB

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading.Tasks;
using MQTTnet.Core.Diagnostics;
using MQTTnet.Core.Exceptions;
using MQTTnet.Core.Internal;
using MQTTnet.Core.Packets;
  8. namespace MQTTnet.Core.Client
  9. {
  10. public class MqttPacketDispatcher
  11. {
  12. private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>();
  13. private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>();
  14. public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout)
  15. {
  16. if (request == null) throw new ArgumentNullException(nameof(request));
  17. var packetAwaiter = AddPacketAwaiter(request, responseType);
  18. try
  19. {
  20. return await packetAwaiter.Task.TimeoutAfter(timeout);
  21. }
  22. catch (MqttCommunicationTimedOutException)
  23. {
  24. MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet of type '{0}'.", responseType.Name);
  25. throw;
  26. }
  27. finally
  28. {
  29. RemovePacketAwaiter(request, responseType);
  30. }
  31. }
  32. public void Dispatch(MqttBasePacket packet)
  33. {
  34. if (packet == null) throw new ArgumentNullException(nameof(packet));
  35. var type = packet.GetType();
  36. if (packet is IMqttPacketWithIdentifier withIdentifier)
  37. {
  38. if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid))
  39. {
  40. if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs))
  41. {
  42. tcs.TrySetResult(packet);
  43. return;
  44. }
  45. }
  46. }
  47. else if (_packetByResponseType.TryRemove(type, out var tcs))
  48. {
  49. tcs.TrySetResult(packet);
  50. return;
  51. }
  52. throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched.");
  53. }
  54. public void Reset()
  55. {
  56. _packetByResponseTypeAndIdentifier.Clear();
  57. _packetByResponseType.Clear();
  58. }
  59. private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(MqttBasePacket request, Type responseType)
  60. {
  61. var tcs = new TaskCompletionSource<MqttBasePacket>();
  62. if (request is IMqttPacketWithIdentifier withIdent)
  63. {
  64. var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
  65. byId[withIdent.PacketIdentifier] = tcs;
  66. }
  67. else
  68. {
  69. _packetByResponseType[responseType] = tcs;
  70. }
  71. return tcs;
  72. }
  73. private void RemovePacketAwaiter(MqttBasePacket request, Type responseType)
  74. {
  75. if (request is IMqttPacketWithIdentifier withIdent)
  76. {
  77. var byId = _packetByResponseTypeAndIdentifier.GetOrAdd(responseType, key => new ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>());
  78. byId.TryRemove(withIdent.PacketIdentifier, out var _);
  79. }
  80. else
  81. {
  82. _packetByResponseType.TryRemove(responseType, out var _);
  83. }
  84. }
  85. }
  86. }