You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

167 lines
6.7 KiB

  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Text;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using MQTTnet.Client;
  7. using MQTTnet.Exceptions;
  8. using MQTTnet.Extensions.Rpc.Options;
  9. using MQTTnet.Extensions.Rpc.Options.TopicGeneration;
  10. using MQTTnet.Protocol;
  11. namespace MQTTnet.Extensions.Rpc
  12. {
  13. public class MqttRpcClient : IDisposable
  14. {
  15. private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>();
  16. private readonly IMqttClient _mqttClient;
  17. private readonly IMqttRpcClientOptions _options;
  18. private readonly RpcAwareApplicationMessageReceivedHandler _applicationMessageReceivedHandler;
  19. [Obsolete("Use MqttRpcClient(IMqttClient mqttClient, IMqttRpcClientOptions options).")]
  20. public MqttRpcClient(IMqttClient mqttClient) : this(mqttClient, new MqttRpcClientOptions())
  21. {
  22. }
  23. public MqttRpcClient(IMqttClient mqttClient, IMqttRpcClientOptions options)
  24. {
  25. _mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient));
  26. _options = options ?? throw new ArgumentNullException(nameof(options));
  27. _applicationMessageReceivedHandler = new RpcAwareApplicationMessageReceivedHandler(
  28. mqttClient.ApplicationMessageReceivedHandler,
  29. HandleApplicationMessageReceivedAsync);
  30. _mqttClient.ApplicationMessageReceivedHandler = _applicationMessageReceivedHandler;
  31. }
  32. public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
  33. {
  34. return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None);
  35. }
  36. public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
  37. {
  38. return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken);
  39. }
  40. public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
  41. {
  42. return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None);
  43. }
  44. public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
  45. {
  46. if (methodName == null) throw new ArgumentNullException(nameof(methodName));
  47. if (methodName.Contains("/") || methodName.Contains("+") || methodName.Contains("#"))
  48. {
  49. throw new ArgumentException("The method name cannot contain /, + or #.");
  50. }
  51. if (!(_mqttClient.ApplicationMessageReceivedHandler is RpcAwareApplicationMessageReceivedHandler))
  52. {
  53. throw new InvalidOperationException("The application message received handler was modified.");
  54. }
  55. var topicNames = _options.TopicGenerationStrategy.CreateRpcTopics(new TopicGenerationContext
  56. {
  57. MethodName = methodName,
  58. QualityOfServiceLevel = qualityOfServiceLevel,
  59. MqttClient = _mqttClient,
  60. Options = _options
  61. });
  62. var requestTopic = topicNames.RequestTopic;
  63. var responseTopic = topicNames.ResponseTopic;
  64. if (string.IsNullOrWhiteSpace(requestTopic))
  65. {
  66. throw new MqttProtocolViolationException("RPC request topic is empty.");
  67. }
  68. if (string.IsNullOrWhiteSpace(responseTopic))
  69. {
  70. throw new MqttProtocolViolationException("RPC response topic is empty.");
  71. }
  72. var requestMessage = new MqttApplicationMessageBuilder()
  73. .WithTopic(requestTopic)
  74. .WithPayload(payload)
  75. .WithQualityOfServiceLevel(qualityOfServiceLevel)
  76. .Build();
  77. try
  78. {
  79. var tcs = new TaskCompletionSource<byte[]>();
  80. if (!_waitingCalls.TryAdd(responseTopic, tcs))
  81. {
  82. throw new InvalidOperationException();
  83. }
  84. await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false);
  85. await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false);
  86. using (var timeoutCts = new CancellationTokenSource(timeout))
  87. using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token))
  88. {
  89. linkedCts.Token.Register(() =>
  90. {
  91. if (!tcs.Task.IsCompleted && !tcs.Task.IsFaulted && !tcs.Task.IsCanceled)
  92. {
  93. tcs.TrySetCanceled();
  94. }
  95. });
  96. try
  97. {
  98. var result = await tcs.Task.ConfigureAwait(false);
  99. timeoutCts.Cancel(false);
  100. return result;
  101. }
  102. catch (OperationCanceledException exception)
  103. {
  104. if (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
  105. {
  106. throw new MqttCommunicationTimedOutException(exception);
  107. }
  108. else
  109. {
  110. throw;
  111. }
  112. }
  113. }
  114. }
  115. finally
  116. {
  117. _waitingCalls.TryRemove(responseTopic, out _);
  118. await _mqttClient.UnsubscribeAsync(responseTopic).ConfigureAwait(false);
  119. }
  120. }
  121. private Task HandleApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventArgs eventArgs)
  122. {
  123. if (!_waitingCalls.TryRemove(eventArgs.ApplicationMessage.Topic, out var tcs))
  124. {
  125. return Task.FromResult(0);
  126. }
  127. tcs.TrySetResult(eventArgs.ApplicationMessage.Payload);
  128. return Task.FromResult(0);
  129. }
  130. public void Dispose()
  131. {
  132. _mqttClient.ApplicationMessageReceivedHandler = _applicationMessageReceivedHandler.OriginalHandler;
  133. foreach (var tcs in _waitingCalls)
  134. {
  135. tcs.Value.TrySetCanceled();
  136. }
  137. _waitingCalls.Clear();
  138. }
  139. }
  140. }