using System;
using System.Collections.Concurrent;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Client;
using MQTTnet.Exceptions;
using MQTTnet.Protocol;

namespace MQTTnet.Extensions.Rpc
{
    /// <summary>
    /// Request/response helper layered on top of an <see cref="IMqttClient"/>.
    /// Each call publishes the request to a unique topic and waits for a message
    /// on that topic with "/response" appended.
    /// </summary>
    /// <remarks>
    /// The constructor replaces the client's application message received handler with a
    /// wrapper; <see cref="Dispose"/> puts the original handler back.
    /// </remarks>
    public class MqttRpcClient : IDisposable
    {
        // Pending calls keyed by their response topic. An entry is completed by the message
        // handler, by timeout/cancellation inside ExecuteAsync, or by Dispose.
        private readonly ConcurrentDictionary<string, TaskCompletionSource<byte[]>> _waitingCalls =
            new ConcurrentDictionary<string, TaskCompletionSource<byte[]>>();

        private readonly IMqttClient _mqttClient;
        private readonly RpcAwareApplicationMessageReceivedHandler _applicationMessageReceivedHandler;

        /// <summary>
        /// Creates the RPC client and installs its message handler on <paramref name="mqttClient"/>.
        /// </summary>
        /// <param name="mqttClient">The MQTT client used for subscribing and publishing.</param>
        /// <exception cref="ArgumentNullException"><paramref name="mqttClient"/> is null.</exception>
        public MqttRpcClient(IMqttClient mqttClient)
        {
            _mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient));

            // The wrapper receives the handler that was installed before us so it can be
            // restored in Dispose (via OriginalHandler).
            _applicationMessageReceivedHandler = new RpcAwareApplicationMessageReceivedHandler(
                mqttClient.ApplicationMessageReceivedHandler,
                HandleApplicationMessageReceivedAsync);

            _mqttClient.ApplicationMessageReceivedHandler = _applicationMessageReceivedHandler;
        }

        /// <summary>
        /// Executes an RPC call with a UTF-8 encoded string payload and no external cancellation.
        /// </summary>
        /// <param name="timeout">Maximum time to wait for the response.</param>
        /// <param name="methodName">Name of the remote method; must not contain /, + or #.</param>
        /// <param name="payload">Request payload; must not be null because it is UTF-8 encoded here.</param>
        /// <param name="qualityOfServiceLevel">QoS level used for the subscription and the request.</param>
        /// <returns>The payload of the response message.</returns>
        public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
        {
            return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, CancellationToken.None);
        }

        /// <summary>
        /// Executes an RPC call with a UTF-8 encoded string payload.
        /// </summary>
        /// <param name="timeout">Maximum time to wait for the response.</param>
        /// <param name="methodName">Name of the remote method; must not contain /, + or #.</param>
        /// <param name="payload">Request payload; must not be null because it is UTF-8 encoded here.</param>
        /// <param name="qualityOfServiceLevel">QoS level used for the subscription and the request.</param>
        /// <param name="cancellationToken">Token that aborts waiting for the response.</param>
        /// <returns>The payload of the response message.</returns>
        public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
        {
            return ExecuteAsync(timeout, methodName, Encoding.UTF8.GetBytes(payload), qualityOfServiceLevel, cancellationToken);
        }

        /// <summary>
        /// Executes an RPC call with a binary payload and no external cancellation.
        /// </summary>
        /// <param name="timeout">Maximum time to wait for the response.</param>
        /// <param name="methodName">Name of the remote method; must not contain /, + or #.</param>
        /// <param name="payload">Request payload, passed to the message builder as is.</param>
        /// <param name="qualityOfServiceLevel">QoS level used for the subscription and the request.</param>
        /// <returns>The payload of the response message.</returns>
        public Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel)
        {
            return ExecuteAsync(timeout, methodName, payload, qualityOfServiceLevel, CancellationToken.None);
        }

        /// <summary>
        /// Executes an RPC call: subscribes to the response topic, publishes the request and
        /// waits until a response arrives, the timeout elapses or the token is cancelled.
        /// </summary>
        /// <param name="timeout">
        /// Maximum time to wait for the response. The timer starts after the subscribe and
        /// publish calls have completed, so those two operations are not bounded by it.
        /// </param>
        /// <param name="methodName">Name of the remote method; must not contain /, + or #.</param>
        /// <param name="payload">Request payload, passed to the message builder as is.</param>
        /// <param name="qualityOfServiceLevel">QoS level used for the subscription and the request.</param>
        /// <param name="cancellationToken">Token that aborts waiting for the response.</param>
        /// <returns>The payload of the response message.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="methodName"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="methodName"/> contains /, + or #.</exception>
        /// <exception cref="InvalidOperationException">
        /// The client's message handler is no longer an RPC aware handler, or the generated
        /// response topic is already registered.
        /// </exception>
        /// <exception cref="MqttCommunicationTimedOutException">No response arrived within <paramref name="timeout"/>.</exception>
        /// <exception cref="OperationCanceledException">
        /// The wait was cancelled by <paramref name="cancellationToken"/> or by <see cref="Dispose"/>.
        /// </exception>
        public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, CancellationToken cancellationToken)
        {
            if (methodName == null)
            {
                throw new ArgumentNullException(nameof(methodName));
            }

            // The method name becomes a single topic level, so separators and wildcards are rejected.
            if (methodName.Contains("/") || methodName.Contains("+") || methodName.Contains("#"))
            {
                throw new ArgumentException("The method name cannot contain /, + or #.");
            }

            // Without our handler in place no response would ever reach _waitingCalls.
            if (!(_mqttClient.ApplicationMessageReceivedHandler is RpcAwareApplicationMessageReceivedHandler))
            {
                throw new InvalidOperationException("The application message received handler was modified.");
            }

            // The GUID makes both topics unique per call.
            var requestTopic = $"MQTTnet.RPC/{Guid.NewGuid():N}/{methodName}";
            var responseTopic = requestTopic + "/response";

            var requestMessage = new MqttApplicationMessageBuilder()
                .WithTopic(requestTopic)
                .WithPayload(payload)
                .WithQualityOfServiceLevel(qualityOfServiceLevel)
                .Build();

            // Register before entering the try block: if registration fails, the cleanup in
            // finally must not remove or unsubscribe an entry that belongs to another call.
            var tcs = new TaskCompletionSource<byte[]>();
            if (!_waitingCalls.TryAdd(responseTopic, tcs))
            {
                throw new InvalidOperationException();
            }

            try
            {
                // Subscribe first so the response cannot be missed once the request is out.
                await _mqttClient.SubscribeAsync(responseTopic, qualityOfServiceLevel).ConfigureAwait(false);
                await _mqttClient.PublishAsync(requestMessage).ConfigureAwait(false);

                using (var timeoutCts = new CancellationTokenSource(timeout))
                using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token))
                using (linkedCts.Token.Register(() => tcs.TrySetCanceled()))
                {
                    // TrySetCanceled is a no-op when the task already completed, so the callback
                    // needs no state checks. Disposing the sources also releases the timeout timer.
                    try
                    {
                        return await tcs.Task.ConfigureAwait(false);
                    }
                    catch (OperationCanceledException exception)
                    {
                        // Only the timeout is translated; caller cancellation and disposal
                        // surface as the original cancellation exception.
                        if (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
                        {
                            throw new MqttCommunicationTimedOutException(exception);
                        }

                        throw;
                    }
                }
            }
            finally
            {
                _waitingCalls.TryRemove(responseTopic, out _);
                await _mqttClient.UnsubscribeAsync(responseTopic).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Completes the pending call whose response topic matches the received message.
        /// Messages on other topics are ignored by this method.
        /// </summary>
        private Task HandleApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventArgs eventArgs)
        {
            if (_waitingCalls.TryRemove(eventArgs.ApplicationMessage.Topic, out var tcs))
            {
                tcs.TrySetResult(eventArgs.ApplicationMessage.Payload);
            }

            return Task.FromResult(0);
        }

        /// <summary>
        /// Restores the handler that was installed on the MQTT client before this instance was
        /// created and cancels all calls that are still waiting for a response.
        /// </summary>
        public void Dispose()
        {
            _mqttClient.ApplicationMessageReceivedHandler = _applicationMessageReceivedHandler.OriginalHandler;

            foreach (var tcs in _waitingCalls)
            {
                tcs.Value.TrySetCanceled();
            }

            _waitingCalls.Clear();
        }
    }
}