You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

121 lines
4.0 KiB

  using System;
  using System.Collections.Concurrent;
  using System.Collections.Generic;
  using System.Threading;
  using System.Threading.Tasks;
  using MQTTnet.Core.Diagnostics;
  using MQTTnet.Core.Exceptions;
  using MQTTnet.Core.Packets;
  namespace MQTTnet.Core.Client
  {
      /// <summary>
      /// Hands packets received from the broker to the caller that is waiting for them.
      /// </summary>
      /// <remarks>
      /// An awaiter registered for a request that implements <see cref="IMqttPacketWithIdentifier"/> is matched
      /// by packet identifier; any other awaiter is matched by the concrete type of the incoming packet.
      /// Packets that arrive while no matching awaiter exists are buffered and re-checked every time a new
      /// awaiter is registered, so a response that arrives before its awaiter is registered is not lost.
      /// All bookkeeping happens under one lock so that "look up awaiter / buffer packet" and
      /// "register awaiter / replay buffer" cannot interleave; awaiters are completed only after the lock
      /// has been released.
      /// </remarks>
      public class MqttPacketDispatcher
      {
          private readonly object _syncRoot = new object();

          // Packets received without a matching awaiter, kept for replay. Guarded by _syncRoot.
          private readonly HashSet<MqttBasePacket> _receivedPackets = new HashSet<MqttBasePacket>();

          // Awaiters for requests without a packet identifier, keyed by expected response type. Guarded by _syncRoot.
          private readonly Dictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new Dictionary<Type, TaskCompletionSource<MqttBasePacket>>();

          // Awaiters for requests with a packet identifier, keyed by that identifier. Guarded by _syncRoot.
          private readonly Dictionary<ushort, TaskCompletionSource<MqttBasePacket>> _packetByIdentifier = new Dictionary<ushort, TaskCompletionSource<MqttBasePacket>>();

          /// <summary>
          /// Waits until a response to <paramref name="request"/> is dispatched or <paramref name="timeout"/> elapses.
          /// </summary>
          /// <param name="request">The request whose response is awaited. If it implements
          /// <see cref="IMqttPacketWithIdentifier"/>, the response is matched by its packet identifier.</param>
          /// <param name="responseType">The concrete packet type to wait for; used for matching only when
          /// <paramref name="request"/> does not carry a packet identifier.</param>
          /// <param name="timeout">How long to wait before giving up.</param>
          /// <returns>The dispatched response packet.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="request"/> is null, or
          /// <paramref name="responseType"/> is null while it is needed for matching.</exception>
          /// <exception cref="MqttCommunicationTimedOutException">No matching packet was dispatched in time.</exception>
          public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout)
          {
              if (request == null) throw new ArgumentNullException(nameof(request));

              var packetAwaiter = AddPacketAwaiter(request, responseType);
              try
              {
                  using (var timeoutCts = new CancellationTokenSource())
                  {
                      var completedTask = await Task.WhenAny(Task.Delay(timeout, timeoutCts.Token), packetAwaiter.Task).ConfigureAwait(false);
                      if (completedTask != packetAwaiter.Task)
                      {
                          MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet.");
                          throw new MqttCommunicationTimedOutException();
                      }

                      // The packet won the race; stop the timer instead of letting it run for the rest of the timeout.
                      timeoutCts.Cancel();
                  }

                  // The task has already completed; awaiting it (rather than reading .Result) avoids AggregateException wrapping.
                  return await packetAwaiter.Task.ConfigureAwait(false);
              }
              finally
              {
                  // Runs on success, timeout and failure alike, so no awaiter is left registered.
                  RemovePacketAwaiter(request, responseType, packetAwaiter);
              }
          }

          /// <summary>
          /// Delivers a received packet to the awaiter waiting for it, or buffers the packet if no awaiter matches yet.
          /// </summary>
          /// <param name="packet">The received packet.</param>
          /// <exception cref="ArgumentNullException"><paramref name="packet"/> is null.</exception>
          public void Dispatch(MqttBasePacket packet)
          {
              if (packet == null) throw new ArgumentNullException(nameof(packet));

              TaskCompletionSource<MqttBasePacket> packetAwaiter;
              lock (_syncRoot)
              {
                  // Lookup and buffering must be atomic with respect to AddPacketAwaiter; otherwise an awaiter
                  // registered in between would replay the buffer before this packet was added to it.
                  packetAwaiter = TakePacketAwaiter(packet);
                  if (packetAwaiter == null)
                  {
                      _receivedPackets.Add(packet);
                      return;
                  }

                  _receivedPackets.Remove(packet);
              }

              // Completed outside the lock so the waiting caller's continuation never runs while the lock is held.
              packetAwaiter.TrySetResult(packet);
          }

          /// <summary>
          /// Discards all buffered packets and all registered awaiters.
          /// </summary>
          /// <remarks>
          /// Discarded awaiters are not completed; their callers keep waiting until their timeout elapses.
          /// </remarks>
          public void Reset()
          {
              lock (_syncRoot)
              {
                  _receivedPackets.Clear();
                  // Both awaiter maps are cleared so that no stale awaiter of either kind survives a reset.
                  _packetByIdentifier.Clear();
                  _packetByResponseType.Clear();
              }
          }

          // Registers a new awaiter for the response to request, then replays the buffered packets in case the
          // response (or a packet for any other registered awaiter) arrived before the awaiter existed.
          // An awaiter previously registered under the same key is replaced.
          private TaskCompletionSource<MqttBasePacket> AddPacketAwaiter(MqttBasePacket request, Type responseType)
          {
              var packetAwaiter = new TaskCompletionSource<MqttBasePacket>();
              List<KeyValuePair<MqttBasePacket, TaskCompletionSource<MqttBasePacket>>> deliveries;
              lock (_syncRoot)
              {
                  if (request is IMqttPacketWithIdentifier withIdentifier)
                  {
                      _packetByIdentifier[withIdentifier.PacketIdentifier] = packetAwaiter;
                  }
                  else
                  {
                      _packetByResponseType[responseType] = packetAwaiter;
                  }

                  deliveries = TakeDeliverablePendingPackets();
              }

              // Completed outside the lock, for the same reason as in Dispatch.
              foreach (var delivery in deliveries)
              {
                  delivery.Value.TrySetResult(delivery.Key);
              }

              return packetAwaiter;
          }

          // Unregisters packetAwaiter, but only if it is still the awaiter stored for its key: a concurrent
          // WaitForPacketAsync for the same key may have replaced it, and that newer awaiter must stay registered.
          private void RemovePacketAwaiter(MqttBasePacket request, Type responseType, TaskCompletionSource<MqttBasePacket> packetAwaiter)
          {
              lock (_syncRoot)
              {
                  if (request is IMqttPacketWithIdentifier withIdentifier)
                  {
                      RemoveAwaiterIfOwned(_packetByIdentifier, withIdentifier.PacketIdentifier, packetAwaiter);
                  }
                  else
                  {
                      RemoveAwaiterIfOwned(_packetByResponseType, responseType, packetAwaiter);
                  }
              }
          }

          // Matches every buffered packet against the registered awaiters, removing matched packets and awaiters.
          // Returns the (packet, awaiter) pairs to complete once the lock is released. Caller must hold _syncRoot.
          private List<KeyValuePair<MqttBasePacket, TaskCompletionSource<MqttBasePacket>>> TakeDeliverablePendingPackets()
          {
              var deliveries = new List<KeyValuePair<MqttBasePacket, TaskCompletionSource<MqttBasePacket>>>();
              foreach (var pendingPacket in _receivedPackets)
              {
                  var packetAwaiter = TakePacketAwaiter(pendingPacket);
                  if (packetAwaiter != null)
                  {
                      deliveries.Add(new KeyValuePair<MqttBasePacket, TaskCompletionSource<MqttBasePacket>>(pendingPacket, packetAwaiter));
                  }
              }

              // Removed after the loop because a HashSet cannot be modified while it is being enumerated.
              foreach (var delivery in deliveries)
              {
                  _receivedPackets.Remove(delivery.Key);
              }

              return deliveries;
          }

          // Removes and returns the awaiter matching packet, or null if none is registered. Caller must hold _syncRoot.
          // Packets with an identifier are matched by identifier only; all others by their concrete type.
          private TaskCompletionSource<MqttBasePacket> TakePacketAwaiter(MqttBasePacket packet)
          {
              return packet is IMqttPacketWithIdentifier withIdentifier
                  ? TakeAwaiter(_packetByIdentifier, withIdentifier.PacketIdentifier)
                  : TakeAwaiter(_packetByResponseType, packet.GetType());
          }

          // Removes and returns the awaiter stored under key, or null if there is none.
          private static TaskCompletionSource<MqttBasePacket> TakeAwaiter<TKey>(Dictionary<TKey, TaskCompletionSource<MqttBasePacket>> awaiters, TKey key)
          {
              if (!awaiters.TryGetValue(key, out var packetAwaiter))
              {
                  return null;
              }

              awaiters.Remove(key);
              return packetAwaiter;
          }

          // Removes the entry for key only when it still refers to packetAwaiter (reference comparison).
          private static void RemoveAwaiterIfOwned<TKey>(Dictionary<TKey, TaskCompletionSource<MqttBasePacket>> awaiters, TKey key, TaskCompletionSource<MqttBasePacket> packetAwaiter)
          {
              if (awaiters.TryGetValue(key, out var registered) && ReferenceEquals(registered, packetAwaiter))
              {
                  awaiters.Remove(key);
              }
          }
      }
  }