using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;

// NOTE(review): generic type arguments appear to be missing throughout this file
// (e.g. "Task> SubscribeAsync", "IEnumerable topicFilters", untyped SendAndReceiveAsync(...)
// calls), so it does not compile as found. They are intentionally left exactly as they were;
// restore them from the IMqttClient interface and the response packet each call expects.
namespace MQTTnet.Client
{
    /// <summary>
    /// MQTT client that talks to a server through an <see cref="IMqttChannelAdapter"/> created by an
    /// <see cref="IMqttClientAdapterFactory"/>. While connected it runs a background packet-receiver
    /// task and, when a keep-alive period is configured, a background keep-alive task.
    /// </summary>
    public class MqttClient : IMqttClient
    {
        private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();

        // Restarted whenever a packet is sent; the keep-alive loop only pings after a quiet interval.
        private readonly Stopwatch _sendTracker = new Stopwatch();

        // Serializes the "already disconnecting?" check with the cancellation of _cancellationTokenSource.
        private readonly SemaphoreSlim _disconnectLock = new SemaphoreSlim(1, 1);

        // Hands received response packets to the awaiters registered by SendAndReceiveAsync.
        private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();

        private readonly IMqttClientAdapterFactory _adapterFactory;
        private readonly IMqttNetLogger _logger;

        // Per-connection state, set up in ConnectAsync. _adapter and _cancellationTokenSource are
        // disposed and reset to null by DisconnectInternalAsync.
        private IMqttClientOptions _options;
        private CancellationTokenSource _cancellationTokenSource;
        private Task _packetReceiverTask;
        private Task _keepAliveMessageSenderTask;
        private IMqttChannelAdapter _adapter;

        /// <summary>Creates a client that is not yet connected.</summary>
        /// <param name="channelFactory">Creates the channel adapter used for each connection.</param>
        /// <param name="logger">Receives diagnostic output.</param>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger)
        {
            _adapterFactory = channelFactory ?? throw new ArgumentNullException(nameof(channelFactory));
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        }

        /// <summary>Raised after the CONNECT handshake with the server has succeeded.</summary>
        public event EventHandler Connected;

        /// <summary>
        /// Raised when the connection is torn down, including after a failed <see cref="ConnectAsync"/>.
        /// </summary>
        public event EventHandler Disconnected;

        /// <summary>
        /// Raised for every received PUBLISH packet. Exceptions thrown by handlers are logged and swallowed.
        /// </summary>
        public event EventHandler ApplicationMessageReceived;

        /// <summary>True between a successful <see cref="ConnectAsync"/> and the next disconnect.</summary>
        public bool IsConnected { get; private set; }

        /// <summary>
        /// Connects the channel adapter, starts the packet receiver, performs the CONNECT handshake and,
        /// when <c>KeepAlivePeriod</c> is non-zero, starts the keep-alive sender.
        /// </summary>
        /// <param name="options">Connection options; <c>ChannelOptions</c> must be set.</param>
        /// <returns>A result carrying the server's session-present flag.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
        /// <exception cref="ArgumentException"><c>options.ChannelOptions</c> is null.</exception>
        /// <exception cref="MqttProtocolViolationException">The client is already connected.</exception>
        /// <remarks>On any failure the client is torn down and the original exception is rethrown.</remarks>
        public async Task ConnectAsync(IMqttClientOptions options)
        {
            if (options == null) throw new ArgumentNullException(nameof(options));
            if (options.ChannelOptions == null) throw new ArgumentException("ChannelOptions are not set.");

            ThrowIfConnected("It is not allowed to connect with a server after the connection is established.");

            try
            {
                _cancellationTokenSource = new CancellationTokenSource();
                _options = options;
                _packetIdentifierProvider.Reset();
                _packetDispatcher.Reset();

                _adapter = _adapterFactory.CreateClientAdapter(options, _logger);

                _logger.Verbose("Trying to connect with server.");
                await _adapter.ConnectAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false);
                _logger.Verbose("Connection with server established.");

                // The receiver must already be running so the response to CONNECT can be dispatched.
                StartReceivingPackets(_cancellationTokenSource.Token);

                var connectResponse = await AuthenticateAsync(options.WillMessage, _cancellationTokenSource.Token).ConfigureAwait(false);
                _logger.Verbose("MQTT connection with server established.");

                _sendTracker.Restart();

                if (_options.KeepAlivePeriod != TimeSpan.Zero)
                {
                    StartSendingKeepAliveMessages(_cancellationTokenSource.Token);
                }

                IsConnected = true;
                Connected?.Invoke(this, new MqttClientConnectedEventArgs(connectResponse.IsSessionPresent));

                _logger.Info("Connected.");
                return new MqttClientConnectResult(connectResponse.IsSessionPresent);
            }
            catch (Exception exception)
            {
                _logger.Error(exception, "Error while connecting with server.");

                await DisconnectInternalAsync(null, exception).ConfigureAwait(false);
                throw;
            }
        }

        /// <summary>
        /// Sends a DISCONNECT packet (when still connected) and then tears the connection down.
        /// The teardown runs even if sending the packet fails.
        /// </summary>
        public async Task DisconnectAsync()
        {
            try
            {
                if (IsConnected && !_cancellationTokenSource.IsCancellationRequested)
                {
                    await SendAsync(new MqttDisconnectPacket(), _cancellationTokenSource.Token).ConfigureAwait(false);
                }
            }
            finally
            {
                await DisconnectInternalAsync(null, null).ConfigureAwait(false);
            }
        }

        /// <summary>Subscribes to the given topic filters and waits for the server's response.</summary>
        /// <param name="topicFilters">The filters to subscribe to.</param>
        /// <returns>One result per topic filter, in the order the filters were given.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="topicFilters"/> is null.</exception>
        /// <exception cref="MqttCommunicationException">The client is not connected.</exception>
        /// <exception cref="MqttProtocolViolationException">
        /// The number of return codes differs from the number of topic filters.
        /// </exception>
        public async Task> SubscribeAsync(IEnumerable topicFilters)
        {
            if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

            ThrowIfNotConnected();

            var subscribePacket = new MqttSubscribePacket
            {
                PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(),
                TopicFilters = topicFilters.ToList()
            };

            var response = await SendAndReceiveAsync(subscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false);

            if (response.SubscribeReturnCodes.Count != subscribePacket.TopicFilters.Count)
            {
                throw new MqttProtocolViolationException("The return codes are not matching the topic filters [MQTT-3.9.3-1].");
            }

            return subscribePacket.TopicFilters.Select((t, i) => new MqttSubscribeResult(t, response.SubscribeReturnCodes[i])).ToList();
        }

        /// <summary>Unsubscribes from the given topic filters and waits for the server's response.</summary>
        /// <param name="topicFilters">The filters to unsubscribe from.</param>
        /// <exception cref="ArgumentNullException"><paramref name="topicFilters"/> is null.</exception>
        /// <exception cref="MqttCommunicationException">The client is not connected.</exception>
        public async Task UnsubscribeAsync(IEnumerable topicFilters)
        {
            if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

            ThrowIfNotConnected();

            var unsubscribePacket = new MqttUnsubscribePacket
            {
                PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(),
                TopicFilters = topicFilters.ToList()
            };

            await SendAndReceiveAsync(unsubscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false);
        }

        /// <summary>
        /// Publishes the messages grouped by QoS level in ascending order (0, then 1, then 2).
        /// Within a group the input order is kept.
        /// QoS 0 messages are sent as one batch without acknowledgement. Each QoS 1 message waits for its
        /// acknowledgement. Each QoS 2 message waits for the PUBREC response, then sends PUBREL and waits
        /// for its response.
        /// </summary>
        /// <param name="applicationMessages">The messages to publish.</param>
        /// <exception cref="ArgumentNullException"><paramref name="applicationMessages"/> is null.</exception>
        /// <exception cref="MqttCommunicationException">The client is not connected.</exception>
        public async Task PublishAsync(IEnumerable applicationMessages)
        {
            // Same guard as SubscribeAsync/UnsubscribeAsync; previously LINQ threw with parameter name "source".
            if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages));

            ThrowIfNotConnected();

            var publishPackets = applicationMessages.Select(m => m.ToPublishPacket());
            var packetGroups = publishPackets.GroupBy(p => p.QualityOfServiceLevel).OrderBy(g => g.Key);

            foreach (var qosGroup in packetGroups)
            {
                switch (qosGroup.Key)
                {
                    case MqttQualityOfServiceLevel.AtMostOnce:
                        {
                            // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier]
                            await SendAsync(qosGroup, _cancellationTokenSource.Token).ConfigureAwait(false);
                            break;
                        }

                    case MqttQualityOfServiceLevel.AtLeastOnce:
                        {
                            foreach (var publishPacket in qosGroup)
                            {
                                publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier();
                                await SendAndReceiveAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false);
                            }

                            break;
                        }

                    case MqttQualityOfServiceLevel.ExactlyOnce:
                        {
                            foreach (var publishPacket in qosGroup)
                            {
                                publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier();
                                var pubRecPacket = await SendAndReceiveAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false);

                                var pubRelPacket = new MqttPubRelPacket
                                {
                                    PacketIdentifier = pubRecPacket.PacketIdentifier
                                };

                                await SendAndReceiveAsync(pubRelPacket, _cancellationTokenSource.Token).ConfigureAwait(false);
                            }

                            break;
                        }

                    default:
                        {
                            throw new InvalidOperationException();
                        }
                }
            }
        }

        /// <summary>
        /// Disposes the cancellation token source and the adapter. This does not send DISCONNECT and does
        /// not wait for the background tasks; call <see cref="DisconnectAsync"/> first for a graceful shutdown.
        /// </summary>
        public void Dispose()
        {
            _cancellationTokenSource?.Dispose();
            _cancellationTokenSource = null;

            _adapter?.Dispose();
        }

        /// <summary>
        /// Sends the CONNECT packet built from the current options and waits for the server's response.
        /// </summary>
        /// <exception cref="MqttConnectingFailedException">The server did not accept the connection.</exception>
        private async Task AuthenticateAsync(MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken)
        {
            var connectPacket = new MqttConnectPacket
            {
                ClientId = _options.ClientId,
                Username = _options.Credentials?.Username,
                Password = _options.Credentials?.Password,
                CleanSession = _options.CleanSession,
                KeepAlivePeriod = (ushort)_options.KeepAlivePeriod.TotalSeconds,
                WillMessage = willApplicationMessage
            };

            var response = await SendAndReceiveAsync(connectPacket, cancellationToken).ConfigureAwait(false);
            if (response.ConnectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
            {
                throw new MqttConnectingFailedException(response.ConnectReturnCode);
            }

            return response;
        }

        private void ThrowIfNotConnected()
        {
            if (!IsConnected) throw new MqttCommunicationException("The client is not connected.");
        }

        private void ThrowIfConnected(string message)
        {
            if (IsConnected) throw new MqttProtocolViolationException(message);
        }

        /// <summary>
        /// Tears the connection down in these steps:
        /// 1. cancels the shared token;
        /// 2. waits for the background tasks;
        /// 3. disconnects and disposes the adapter;
        /// 4. raises <see cref="Disconnected"/>.
        /// Only the first caller per connection does this work. Later or concurrent callers return as soon
        /// as they see the token source already cancelled or cleared.
        /// </summary>
        /// <param name="sender">
        /// The background task that is calling, or null. It is not awaited, because a task that waits for
        /// itself never completes.
        /// </param>
        /// <param name="exception">The error that caused the disconnect, or null. It is forwarded to the event.</param>
        private async Task DisconnectInternalAsync(Task sender, Exception exception)
        {
            await _disconnectLock.WaitAsync().ConfigureAwait(false);
            try
            {
                if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested)
                {
                    return;
                }

                _cancellationTokenSource.Cancel(false);
            }
            catch (Exception adapterException)
            {
                _logger.Warning(adapterException, "Error while disconnecting from adapter.");
            }
            finally
            {
                _disconnectLock.Release();
            }

            var clientWasConnected = IsConnected;
            IsConnected = false;

            try
            {
                // WaitForTaskAsync already skips the sender and swallows cancellation. A second direct
                // await of the keep-alive task (as before) could throw TaskCanceledException. That would
                // skip the adapter disconnect below.
                await WaitForTaskAsync(_packetReceiverTask, sender).ConfigureAwait(false);
                await WaitForTaskAsync(_keepAliveMessageSenderTask, sender).ConfigureAwait(false);

                if (_adapter != null)
                {
                    await _adapter.DisconnectAsync(_options.CommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
                }

                _logger.Verbose("Disconnected from adapter.");
            }
            catch (Exception adapterException)
            {
                _logger.Warning(adapterException, "Error while disconnecting from adapter.");
            }
            finally
            {
                _adapter?.Dispose();
                _adapter = null;

                _cancellationTokenSource?.Dispose();
                _cancellationTokenSource = null;

                _logger.Info("Disconnected.");
                Disconnected?.Invoke(this, new MqttClientDisconnectedEventArgs(clientWasConnected, exception));
            }
        }

        private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken)
        {
            return SendAsync(new[] { packet }, cancellationToken);
        }

        /// <summary>Sends the packets through the adapter and restarts the keep-alive send tracker.</summary>
        /// <exception cref="TaskCanceledException"><paramref name="cancellationToken"/> is already cancelled.</exception>
        private Task SendAsync(IEnumerable packets, CancellationToken cancellationToken)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                throw new TaskCanceledException();
            }

            _sendTracker.Restart();
            return _adapter.SendPacketsAsync(_options.CommunicationTimeout, packets, cancellationToken);
        }

        /// <summary>
        /// Registers an awaiter with the dispatcher, sends <paramref name="requestPacket"/> and waits up to
        /// <c>CommunicationTimeout</c> for the matching response.
        /// Packets without an identifier are registered under identifier 0.
        /// </summary>
        /// <exception cref="TaskCanceledException"><paramref name="cancellationToken"/> is already cancelled.</exception>
        /// <exception cref="MqttCommunicationTimedOutException">No response arrived in time (logged, then rethrown).</exception>
        private async Task SendAndReceiveAsync(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket
        {
            if (cancellationToken.IsCancellationRequested)
            {
                throw new TaskCanceledException();
            }

            _sendTracker.Restart();

            ushort identifier = 0;
            if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier.HasValue)
            {
                identifier = packetWithIdentifier.PacketIdentifier.Value;
            }

            var packetAwaiter = _packetDispatcher.AddPacketAwaiter(identifier);
            try
            {
                await _adapter.SendPacketsAsync(_options.CommunicationTimeout, new[] { requestPacket }, cancellationToken).ConfigureAwait(false);
                var response = await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false);

                return (TResponsePacket)response;
            }
            catch (MqttCommunicationTimedOutException)
            {
                // Log the packet type name (previously logged the namespace, e.g. "MQTTnet.Packets").
                _logger.Warning($"Timeout while waiting for packet of type '{typeof(TResponsePacket).Name}'.");
                throw;
            }
            finally
            {
                _packetDispatcher.RemovePacketAwaiter(identifier);
            }
        }

        /// <summary>
        /// Background loop. It sends a ping whenever nothing has been sent for longer than the send
        /// interval, then sleeps for that interval. Any failure other than cancellation is logged, and
        /// every failure tears the connection down.
        /// </summary>
        private async Task SendKeepAliveMessagesAsync(CancellationToken cancellationToken)
        {
            _logger.Verbose("Start sending keep alive packets.");

            try
            {
                // Loop-invariant: 75% of the configured KeepAlivePeriod unless an explicit interval is set.
                var keepAliveSendInterval = TimeSpan.FromSeconds(_options.KeepAlivePeriod.TotalSeconds * 0.75);
                if (_options.KeepAliveSendInterval.HasValue)
                {
                    keepAliveSendInterval = _options.KeepAliveSendInterval.Value;
                }

                while (!cancellationToken.IsCancellationRequested)
                {
                    if (_sendTracker.Elapsed > keepAliveSendInterval)
                    {
                        await SendAndReceiveAsync(new MqttPingReqPacket(), cancellationToken).ConfigureAwait(false);
                    }

                    await Task.Delay(keepAliveSendInterval, cancellationToken).ConfigureAwait(false);
                }
            }
            catch (Exception exception)
            {
                if (exception is OperationCanceledException)
                {
                }
                else if (exception is MqttCommunicationException)
                {
                    _logger.Warning(exception, "MQTT communication exception while sending/receiving keep alive packets.");
                }
                else
                {
                    _logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets.");
                }

                await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false);
            }
            finally
            {
                _logger.Verbose("Stopped sending keep alive packets.");
            }
        }

        /// <summary>
        /// Background loop that reads packets from the adapter and processes them. Depending on
        /// <c>ReceivedApplicationMessageProcessingMode</c>, processing runs inline or on a separate task.
        /// On failure it tears the connection down, then forwards the exception to the dispatcher so
        /// pending awaiters are released.
        /// </summary>
        private async Task ReceivePacketsAsync(CancellationToken cancellationToken)
        {
            _logger.Verbose("Start receiving packets.");

            try
            {
                while (!cancellationToken.IsCancellationRequested)
                {
                    var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false);
                    if (cancellationToken.IsCancellationRequested)
                    {
                        return;
                    }

                    if (packet == null)
                    {
                        continue;
                    }

                    if (packet is MqttDisconnectPacket)
                    {
                        // Handled here, with the receiver task passed as the sender.
                        // DisconnectInternalAsync awaits the receiver task unless it is the sender. The old
                        // path (ProcessReceivedPacketAsync -> DisconnectAsync) awaited this very task from
                        // inside it and deadlocked in SingleThread mode. No DISCONNECT is echoed back; that
                        // packet is only sent client-to-server.
                        await DisconnectInternalAsync(_packetReceiverTask, null).ConfigureAwait(false);
                        return;
                    }

                    if (_options.ReceivedApplicationMessageProcessingMode == MqttReceivedApplicationMessageProcessingMode.SingleThread)
                    {
                        await ProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false);
                    }
                    else if (_options.ReceivedApplicationMessageProcessingMode == MqttReceivedApplicationMessageProcessingMode.DedicatedThread)
                    {
                        StartProcessReceivedPacketAsync(packet, cancellationToken);
                    }
                }
            }
            catch (Exception exception)
            {
                if (exception is OperationCanceledException)
                {
                }
                else if (exception is MqttCommunicationException)
                {
                    _logger.Warning(exception, "MQTT communication exception while receiving packets.");
                }
                else
                {
                    _logger.Error(exception, "Unhandled exception while receiving packets.");
                }

                await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false);
                _packetDispatcher.Dispatch(exception);
            }
            finally
            {
                _logger.Verbose("Stopped receiving packets.");
            }
        }

        /// <summary>
        /// Handles one received packet:
        /// - PUBLISH is delivered and acknowledged according to its QoS;
        /// - PINGREQ is answered with PINGRESP;
        /// - PUBREL is answered with PUBCOMP;
        /// - anything else is handed to the dispatcher.
        /// Exceptions are logged, not propagated.
        /// </summary>
        private async Task ProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken)
        {
            try
            {
                if (packet is MqttPublishPacket publishPacket)
                {
                    await ProcessReceivedPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false);
                    return;
                }

                if (packet is MqttPingReqPacket)
                {
                    await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false);
                    return;
                }

                if (packet is MqttPubRelPacket pubRelPacket)
                {
                    await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false);
                    return;
                }

                _packetDispatcher.Dispatch(packet);
            }
            catch (Exception exception)
            {
                _logger.Error(exception, "Unhandled exception while processing received packet.");
            }
        }

        /// <summary>
        /// Raises <see cref="ApplicationMessageReceived"/> and sends the acknowledgement its QoS requires:
        /// none for QoS 0, PUBACK for QoS 1, PUBREC for QoS 2.
        /// </summary>
        /// <exception cref="MqttCommunicationException">The QoS level is not one of the three known values.</exception>
        private Task ProcessReceivedPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken)
        {
            if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce)
            {
                FireApplicationMessageReceivedEvent(publishPacket);
                return Task.FromResult(0);
            }

            if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce)
            {
                FireApplicationMessageReceivedEvent(publishPacket);
                return SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }, cancellationToken);
            }

            if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce)
            {
                // QoS 2 is implemented as method "B" [4.3.3 QoS 2: Exactly once delivery]
                FireApplicationMessageReceivedEvent(publishPacket);
                return SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }, cancellationToken);
            }

            throw new MqttCommunicationException("Received a not supported QoS level.");
        }

        private Task ProcessReceivedPubRelPacket(MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken)
        {
            var response = new MqttPubCompPacket
            {
                PacketIdentifier = pubRelPacket.PacketIdentifier
            };

            return SendAsync(response, cancellationToken);
        }

        private void StartReceivingPackets(CancellationToken cancellationToken)
        {
            _packetReceiverTask = Task.Run(() => ReceivePacketsAsync(cancellationToken), cancellationToken);
        }

        private void StartSendingKeepAliveMessages(CancellationToken cancellationToken)
        {
            _keepAliveMessageSenderTask = Task.Run(() => SendKeepAliveMessagesAsync(cancellationToken), cancellationToken);
        }

        private void StartProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken)
        {
            // Fire-and-forget by design; ProcessReceivedPacketAsync logs its own exceptions.
            _ = Task.Run(() => ProcessReceivedPacketAsync(packet, cancellationToken), cancellationToken);
        }

        /// <summary>Raises <see cref="ApplicationMessageReceived"/>; handler exceptions are logged and swallowed.</summary>
        private void FireApplicationMessageReceivedEvent(MqttPublishPacket publishPacket)
        {
            try
            {
                var applicationMessage = publishPacket.ToApplicationMessage();
                ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(_options.ClientId, applicationMessage));
            }
            catch (Exception exception)
            {
                _logger.Error(exception, "Unhandled exception while handling application message.");
            }
        }

        /// <summary>
        /// Waits for <paramref name="task"/> unless it is null, is the calling <paramref name="sender"/>,
        /// or has already finished. Cancellation of the awaited task is swallowed.
        /// </summary>
        private static async Task WaitForTaskAsync(Task task, Task sender)
        {
            // IsCompleted is also true for canceled and faulted tasks.
            if (task == null || task == sender || task.IsCompleted)
            {
                return;
            }

            try
            {
                await task.ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Cancellation is the expected way for the background tasks to end.
            }
        }
    }
}