diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index d77b7f8..a4417c3 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -2,7 +2,7 @@ MQTTnet - 2.3.1 + 2.4.0 Christian Kratky Christian Kratky https://github.com/chkr1011/MQTTnet/blob/master/LICENSE @@ -10,10 +10,12 @@ https://raw.githubusercontent.com/chkr1011/MQTTnet/master/Images/Logo_128x128.png false MQTTnet is a .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker). - * [Server] Fixed an issue when accepting a new connection (UWP only) (Thanks to haeberle) -[Core] Fixed a dead lock while sending messages (Thanks to 1liveowl, JanEggers) -[Client] The client is no longer sending packets before receiving has started -[Core] Minor changes and improvements + * [Server] Added an event which is fired when a client has disconnected. +* [Server] Added support for retained application messages +* [Server] Added support for saving and loading retained messages +* [Server] The client connection is now closed if sending of one pending application message has failed +* [Server] Fixed handling of _Dup_ flag (Thanks to haeberle) +* [Core] Optimized exception handling Copyright Christian Kratky 2016-2017 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs index be6f778..c0ab6a6 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs @@ -24,7 +24,7 @@ namespace MQTTnet.Implementations public event EventHandler ClientConnected; - public void Start(MqttServerOptions options) + public Task StartAsync(MqttServerOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -57,9 +57,11 @@ namespace MQTTnet.Implementations Task.Run(() => AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); } + + return Task.FromResult(0); } - public void Stop() + public Task StopAsync() { _isRunning = false; @@ -72,11 +74,13 @@ namespace MQTTnet.Implementations _tlsEndpointSocket?.Dispose(); _tlsEndpointSocket = null; + + return Task.FromResult(0); } public void Dispose() { - Stop(); + StopAsync(); } private async Task AcceptDefaultEndpointConnectionsAsync(CancellationToken cancellationToken) diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs index 112d737..325f622 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs @@ -24,7 +24,7 @@ namespace MQTTnet.Implementations public event EventHandler ClientConnected; - public void Start(MqttServerOptions options) + public Task StartAsync(MqttServerOptions options) { if (_isRunning) throw new InvalidOperationException("Server is already started."); @@ -56,9 +56,11 @@ namespace MQTTnet.Implementations Task.Run(() => AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); } + + return Task.FromResult(0); } - public void Stop() + public Task StopAsync() { _isRunning = false; @@ -71,11 +73,13 @@ namespace MQTTnet.Implementations _tlsEndpointSocket?.Dispose(); _tlsEndpointSocket = null; + + return Task.FromResult(0); } public void Dispose() { - Stop(); + StopAsync(); } private async Task AcceptDefaultEndpointConnectionsAsync(CancellationToken cancellationToken) diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs index 2258ab3..3c1300c 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Threading.Tasks; using MQTTnet.Core.Adapter; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Serializer; @@ -15,7 +16,7 @@ namespace MQTTnet.Implementations public event EventHandler ClientConnected; - public void Start(MqttServerOptions options) + public async Task StartAsync(MqttServerOptions options) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -26,7 +27,7 @@ namespace MQTTnet.Implementations if (options.DefaultEndpointOptions.IsEnabled) { _defaultEndpointSocket = new StreamSocketListener(); - _defaultEndpointSocket.BindServiceNameAsync(options.GetDefaultEndpointPort().ToString(), SocketProtectionLevel.PlainSocket).GetAwaiter().GetResult(); + await _defaultEndpointSocket.BindServiceNameAsync(options.GetDefaultEndpointPort().ToString(), SocketProtectionLevel.PlainSocket); _defaultEndpointSocket.ConnectionReceived += AcceptDefaultEndpointConnectionsAsync; } @@ -36,17 +37,19 @@ namespace MQTTnet.Implementations } } - public void Stop() + public Task StopAsync() { _isRunning = false; _defaultEndpointSocket?.Dispose(); _defaultEndpointSocket = null; + + return Task.FromResult(0); } public void Dispose() { - Stop(); + StopAsync(); } private void AcceptDefaultEndpointConnectionsAsync(StreamSocketListener sender, StreamSocketListenerConnectionReceivedEventArgs args) diff --git a/MQTTnet.Core/Adapter/IMqttServerAdapter.cs b/MQTTnet.Core/Adapter/IMqttServerAdapter.cs index 98d51bb..416139a 100644 --- a/MQTTnet.Core/Adapter/IMqttServerAdapter.cs +++ b/MQTTnet.Core/Adapter/IMqttServerAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Threading.Tasks; using MQTTnet.Core.Server; namespace MQTTnet.Core.Adapter @@ -7,8 +8,7 @@ namespace MQTTnet.Core.Adapter { event EventHandler ClientConnected; - void Start(MqttServerOptions options); - - void Stop(); + Task StartAsync(MqttServerOptions options); + Task StopAsync(); } } diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index 8eab8cb..600081c 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -36,6 +36,10 @@ namespace MQTTnet.Core.Adapter { throw; } + catch (OperationCanceledException) + { + throw; + } catch (MqttCommunicationTimedOutException) { throw; @@ -60,6 +64,10 @@ namespace MQTTnet.Core.Adapter { throw; } + catch (OperationCanceledException) + { + throw; + } catch (MqttCommunicationTimedOutException) { throw; @@ -87,7 +95,7 @@ namespace MQTTnet.Core.Adapter continue; } - MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), $"TX >>> {0} [Timeout={1}]", packet, timeout); + MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); var writeBuffer = PacketSerializer.Serialize(packet); await _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false); @@ -106,6 +114,10 @@ namespace MQTTnet.Core.Adapter { throw; } + catch (OperationCanceledException) + { + throw; + } catch (MqttCommunicationTimedOutException) { throw; @@ -156,6 +168,10 @@ namespace MQTTnet.Core.Adapter { throw; } + catch (OperationCanceledException) + { + throw; + } catch (MqttCommunicationTimedOutException) { throw; diff --git a/MQTTnet.Core/Adapter/MqttClientDisconnectedEventArgs.cs b/MQTTnet.Core/Adapter/MqttClientDisconnectedEventArgs.cs new file mode 100644 index 0000000..48d49c1 --- /dev/null +++ b/MQTTnet.Core/Adapter/MqttClientDisconnectedEventArgs.cs @@ -0,0 +1,17 @@ +using System; + +namespace MQTTnet.Core.Adapter +{ + public class MqttClientDisconnectedEventArgs : EventArgs + { + public MqttClientDisconnectedEventArgs(string identifier, IMqttCommunicationAdapter clientAdapter) + { + Identifier = identifier ?? throw new ArgumentNullException(nameof(identifier)); + ClientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); + } + + public string Identifier { get; } + + public IMqttCommunicationAdapter ClientAdapter { get; } + } +} diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index c03b1d0..aa42652 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -370,7 +370,7 @@ namespace MQTTnet.Core.Client await SendAndReceiveAsync(new MqttPingReqPacket()).ConfigureAwait(false); } } - catch (TaskCanceledException) + catch (OperationCanceledException) { } catch (MqttCommunicationException exception) @@ -413,7 +413,7 @@ namespace MQTTnet.Core.Client StartProcessReceivedPacket(packet, cancellationToken); } } - catch (TaskCanceledException) + catch (OperationCanceledException) { } catch (MqttCommunicationException exception) diff --git a/MQTTnet.Core/Packets/MqttSubscribePacket.cs b/MQTTnet.Core/Packets/MqttSubscribePacket.cs index 3fa379b..5cf1c26 100644 --- a/MQTTnet.Core/Packets/MqttSubscribePacket.cs +++ b/MQTTnet.Core/Packets/MqttSubscribePacket.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Core.Packets public override string ToString() { - var topicFiltersText = string.Join(",", TopicFilters.Select(f => $"{f.Topic}@{f.QualityOfServiceLevel}")); + var topicFiltersText = string.Join(",", TopicFilters.Select(f => f.Topic + "@" + f.QualityOfServiceLevel)); return nameof(MqttSubscribePacket) + ": [PacketIdentifier=" + PacketIdentifier + "] [TopicFilters=" + topicFiltersText + "]"; } } diff --git a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs index f43fc72..7f070d6 100644 --- a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs @@ -212,8 +212,7 @@ namespace MQTTnet.Core.Serializer default: { - throw new MqttProtocolViolationException( - $"Packet type ({(int)header.ControlPacketType}) not supported."); + throw new MqttProtocolViolationException($"Packet type ({(int)header.ControlPacketType}) not supported."); } } } diff --git a/MQTTnet.Core/Server/IMqttServer.cs b/MQTTnet.Core/Server/IMqttServer.cs index 7a75624..1800806 100644 --- a/MQTTnet.Core/Server/IMqttServer.cs +++ b/MQTTnet.Core/Server/IMqttServer.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Threading.Tasks; using MQTTnet.Core.Adapter; namespace MQTTnet.Core.Server @@ -12,7 +13,8 @@ namespace MQTTnet.Core.Server IList GetConnectedClients(); void InjectClient(string identifier, IMqttCommunicationAdapter adapter); void Publish(MqttApplicationMessage applicationMessage); - void Start(); - void Stop(); + + Task StartAsync(); + Task StopAsync(); } } \ No newline at end of file diff --git a/MQTTnet.Core/Server/IMqttServerStorage.cs b/MQTTnet.Core/Server/IMqttServerStorage.cs new file mode 100644 index 0000000..3cb518c --- /dev/null +++ b/MQTTnet.Core/Server/IMqttServerStorage.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace MQTTnet.Core.Server +{ + public interface IMqttServerStorage + { + Task SaveRetainedMessagesAsync(IList messages); + + Task> LoadRetainedMessagesAsync(); + } +} diff --git a/MQTTnet.Core/Server/MqttClientMessageQueue.cs b/MQTTnet.Core/Server/MqttClientMessageQueue.cs deleted file mode 100644 index acff44f..0000000 --- a/MQTTnet.Core/Server/MqttClientMessageQueue.cs +++ /dev/null @@ -1,89 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Core.Adapter; -using MQTTnet.Core.Diagnostics; -using MQTTnet.Core.Exceptions; -using MQTTnet.Core.Packets; -using System.Linq; - -namespace MQTTnet.Core.Server -{ - public sealed class MqttClientMessageQueue - { - private readonly BlockingCollection _pendingPublishPackets = new BlockingCollection(); - - private readonly MqttServerOptions _options; - private CancellationTokenSource _cancellationTokenSource; - - public MqttClientMessageQueue(MqttServerOptions options) - { - _options = options ?? throw new ArgumentNullException(nameof(options)); - } - - public void Start(IMqttCommunicationAdapter adapter) - { - if (_cancellationTokenSource != null) - { - throw new InvalidOperationException($"{nameof(MqttClientMessageQueue)} already started."); - } - - if (adapter == null) throw new ArgumentNullException(nameof(adapter)); - _cancellationTokenSource = new CancellationTokenSource(); - - Task.Run(() => SendPendingPublishPacketsAsync(_cancellationTokenSource.Token, adapter), _cancellationTokenSource.Token); - } - - public void Stop() - { - _cancellationTokenSource?.Cancel(); - _cancellationTokenSource = null; - _pendingPublishPackets?.Dispose(); - } - - public void Enqueue(MqttPublishPacket publishPacket) - { - if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); - - _pendingPublishPackets.Add(publishPacket); - } - - private async Task SendPendingPublishPacketsAsync(CancellationToken cancellationToken, IMqttCommunicationAdapter adapter) - { - var consumable = _pendingPublishPackets.GetConsumingEnumerable(); - while (!cancellationToken.IsCancellationRequested) - { - if (_pendingPublishPackets.Count == 0) - { - await Task.Delay(TimeSpan.FromMilliseconds(5), cancellationToken).ConfigureAwait(false); - continue; - } - - var packets = consumable.Take(_pendingPublishPackets.Count).ToList(); - try - { - await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, packets).ConfigureAwait(false); - } - catch (MqttCommunicationException exception) - { - MqttTrace.Warning(nameof(MqttClientMessageQueue), exception, "Sending publish packet failed."); - foreach (var publishPacket in packets) - { - publishPacket.Dup = true; - _pendingPublishPackets.Add(publishPacket, cancellationToken); - } - } - catch (Exception exception) - { - MqttTrace.Error(nameof(MqttClientMessageQueue), exception, "Sending publish packet failed."); - foreach (var publishPacket in packets) - { - publishPacket.Dup = true; - _pendingPublishPackets.Add(publishPacket, cancellationToken); - } - } - } - } - } -} diff --git a/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs b/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs new file mode 100644 index 0000000..61a1b3b --- /dev/null +++ b/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Core.Adapter; +using MQTTnet.Core.Diagnostics; +using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Packets; +using MQTTnet.Core.Protocol; + +namespace MQTTnet.Core.Server +{ + public sealed class MqttClientPendingMessagesQueue + { + private readonly BlockingCollection _pendingPublishPackets = new BlockingCollection(); + private readonly MqttClientSession _session; + private readonly MqttServerOptions _options; + + public MqttClientPendingMessagesQueue(MqttServerOptions options, MqttClientSession session) + { + _session = session ?? throw new ArgumentNullException(nameof(session)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public void Start(IMqttCommunicationAdapter adapter, CancellationToken cancellationToken) + { + if (adapter == null) throw new ArgumentNullException(nameof(adapter)); + + Task.Run(() => SendPendingPublishPacketsAsync(adapter, cancellationToken), cancellationToken); + } + + public void Enqueue(MqttPublishPacket publishPacket) + { + if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); + + _pendingPublishPackets.Add(publishPacket); + } + + private async Task SendPendingPublishPacketsAsync(IMqttCommunicationAdapter adapter, CancellationToken cancellationToken) + { + try + { + while (!cancellationToken.IsCancellationRequested) + { + await SendPendingPublishPacketAsync(adapter, cancellationToken); + } + } + catch (OperationCanceledException) + { + } + catch (Exception exception) + { + MqttTrace.Error(nameof(MqttClientPendingMessagesQueue), exception, "Unhandled exception while sending pending publish packets."); + } + } + + private async Task SendPendingPublishPacketAsync(IMqttCommunicationAdapter adapter, CancellationToken cancellationToken) + { + var packet = _pendingPublishPackets.Take(cancellationToken); + + try + { + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, packet).ConfigureAwait(false); + } + catch (Exception exception) + { + if (exception is MqttCommunicationTimedOutException) + { + MqttTrace.Warning(nameof(MqttClientPendingMessagesQueue), exception, "Sending publish packet failed due to timeout."); + } + else if (exception is MqttCommunicationException) + { + MqttTrace.Warning(nameof(MqttClientPendingMessagesQueue), exception, "Sending publish packet failed due to communication exception."); + } + else + { + MqttTrace.Error(nameof(MqttClientPendingMessagesQueue), exception, "Sending publish packet failed."); + } + + if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) + { + packet.Dup = true; + _pendingPublishPackets.Add(packet, cancellationToken); + } + + _session.Stop(); + } + } + } +} diff --git a/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs b/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs new file mode 100644 index 0000000..f962a1a --- /dev/null +++ b/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using MQTTnet.Core.Diagnostics; +using MQTTnet.Core.Internal; +using MQTTnet.Core.Packets; + +namespace MQTTnet.Core.Server +{ + public sealed class MqttClientRetainedMessagesManager + { + private readonly Dictionary _retainedMessages = new Dictionary(); + private readonly MqttServerOptions _options; + + public MqttClientRetainedMessagesManager(MqttServerOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public async Task LoadMessagesAsync() + { + try + { + var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync(); + lock (_retainedMessages) + { + _retainedMessages.Clear(); + foreach (var retainedMessage in retainedMessages) + { + _retainedMessages[retainedMessage.Topic] = retainedMessage.ToPublishPacket(); + } + } + } + catch (Exception exception) + { + MqttTrace.Error(nameof(MqttClientRetainedMessagesManager), exception, "Unhandled exception while loading retained messages."); + } + } + + public async Task HandleMessageAsync(string clientId, MqttPublishPacket publishPacket) + { + if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); + + List allRetainedMessages; + lock (_retainedMessages) + { + if (publishPacket.Payload?.Any() == false) + { + _retainedMessages.Remove(publishPacket.Topic); + MqttTrace.Information(nameof(MqttClientRetainedMessagesManager), "Client '{0}' cleared retained message for topic '{1}'.", clientId, publishPacket.Topic); + } + else + { + _retainedMessages[publishPacket.Topic] = publishPacket; + MqttTrace.Information(nameof(MqttClientRetainedMessagesManager), "Client '{0}' updated retained message for topic '{1}'.", clientId, publishPacket.Topic); + } + + allRetainedMessages = new List(_retainedMessages.Values); + } + + try + { + // ReSharper disable once UseNullPropagation + if (_options.Storage != null) + { + await _options.Storage.SaveRetainedMessagesAsync(allRetainedMessages.Select(p => p.ToApplicationMessage()).ToList()); + } + } + catch (Exception exception) + { + MqttTrace.Error(nameof(MqttClientRetainedMessagesManager), exception, "Unhandled exception while saving retained messages."); + } + } + + public List GetMessages(MqttSubscribePacket subscribePacket) + { + var retainedMessages = new List(); + lock (_retainedMessages) + { + foreach (var retainedMessage in _retainedMessages.Values) + { + foreach (var topicFilter in subscribePacket.TopicFilters) + { + if (retainedMessage.QualityOfServiceLevel < topicFilter.QualityOfServiceLevel) + { + continue; + } + + if (!MqttTopicFilterComparer.IsMatch(retainedMessage.Topic, topicFilter.Topic)) + { + continue; + } + + retainedMessages.Add(retainedMessage); + break; + } + } + } + + return retainedMessages; + } + } +} diff --git a/MQTTnet.Core/Server/MqttClientSession.cs b/MQTTnet.Core/Server/MqttClientSession.cs index 4836396..2115508 100644 --- a/MQTTnet.Core/Server/MqttClientSession.cs +++ b/MQTTnet.Core/Server/MqttClientSession.cs @@ -14,23 +14,22 @@ namespace MQTTnet.Core.Server public sealed class MqttClientSession : IDisposable { private readonly HashSet _unacknowledgedPublishPackets = new HashSet(); - + private readonly MqttClientSubscriptionsManager _subscriptionsManager = new MqttClientSubscriptionsManager(); - private readonly MqttClientMessageQueue _messageQueue; - private readonly Action _publishPacketReceivedCallback; + private readonly MqttClientSessionsManager _mqttClientSessionsManager; + private readonly MqttClientPendingMessagesQueue _pendingMessagesQueue; private readonly MqttServerOptions _options; - private CancellationTokenSource _cancellationTokenSource; private string _identifier; - private MqttApplicationMessage _willApplicationMessage; + private CancellationTokenSource _cancellationTokenSource; + private MqttApplicationMessage _willMessage; - public MqttClientSession(string clientId, MqttServerOptions options, Action publishPacketReceivedCallback) + public MqttClientSession(string clientId, MqttServerOptions options, MqttClientSessionsManager mqttClientSessionsManager) { ClientId = clientId; _options = options ?? throw new ArgumentNullException(nameof(options)); - _publishPacketReceivedCallback = publishPacketReceivedCallback ?? throw new ArgumentNullException(nameof(publishPacketReceivedCallback)); - - _messageQueue = new MqttClientMessageQueue(options); + _mqttClientSessionsManager = mqttClientSessionsManager ?? throw new ArgumentNullException(nameof(mqttClientSessionsManager)); + _pendingMessagesQueue = new MqttClientPendingMessagesQueue(options, this); } public string ClientId { get; } @@ -39,11 +38,11 @@ namespace MQTTnet.Core.Server public IMqttCommunicationAdapter Adapter { get; private set; } - public async Task RunAsync(string identifier, MqttApplicationMessage willApplicationMessage, IMqttCommunicationAdapter adapter) + public async Task RunAsync(string identifier, MqttApplicationMessage willMessage, IMqttCommunicationAdapter adapter) { if (adapter == null) throw new ArgumentNullException(nameof(adapter)); - _willApplicationMessage = willApplicationMessage; + _willMessage = willMessage; try { @@ -51,33 +50,36 @@ namespace MQTTnet.Core.Server Adapter = adapter; _cancellationTokenSource = new CancellationTokenSource(); - _messageQueue.Start(adapter); - while (!_cancellationTokenSource.IsCancellationRequested) - { - var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false); - await HandleIncomingPacketAsync(packet).ConfigureAwait(false); - } + _pendingMessagesQueue.Start(adapter, _cancellationTokenSource.Token); + await ReceivePacketsAsync(adapter, _cancellationTokenSource.Token); + } + catch (OperationCanceledException) + { } - catch (MqttCommunicationException) + catch (MqttCommunicationException exception) { + MqttTrace.Warning(nameof(MqttClientSession), exception, "Client '{0}': Communication exception while processing client packets.", _identifier); } catch (Exception exception) { MqttTrace.Error(nameof(MqttClientSession), exception, "Client '{0}': Unhandled exception while processing client packets.", _identifier); } - finally + } + + public void Stop() + { + if (_willMessage != null) { - if (willApplicationMessage != null) - { - _publishPacketReceivedCallback(this, _willApplicationMessage.ToPublishPacket()); - } + _mqttClientSessionsManager.DispatchPublishPacket(this, _willMessage.ToPublishPacket()); + } - _messageQueue.Stop(); - _cancellationTokenSource.Cancel(); - Adapter = null; + _cancellationTokenSource?.Cancel(false); + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; - MqttTrace.Information(nameof(MqttClientSession), "Client '{0}': Disconnected.", _identifier); - } + Adapter = null; + + MqttTrace.Information(nameof(MqttClientSession), "Client '{0}': Disconnected.", _identifier); } public void EnqueuePublishPacket(MqttPublishPacket publishPacket) @@ -89,7 +91,7 @@ namespace MQTTnet.Core.Server return; } - _messageQueue.Enqueue(publishPacket); + _pendingMessagesQueue.Enqueue(publishPacket); MqttTrace.Verbose(nameof(MqttClientSession), "Client '{0}': Enqueued pending publish packet.", _identifier); } @@ -99,68 +101,100 @@ namespace MQTTnet.Core.Server _cancellationTokenSource?.Dispose(); } - private Task HandleIncomingPacketAsync(MqttBasePacket packet) + private async Task ReceivePacketsAsync(IMqttCommunicationAdapter adapter, CancellationToken cancellationToken) { - if (packet is MqttSubscribePacket subscribePacket) + try { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Subscribe(subscribePacket)); + while (!cancellationToken.IsCancellationRequested) + { + var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); + await ProcessReceivedPacketAsync(packet).ConfigureAwait(false); + } } - - if (packet is MqttUnsubscribePacket unsubscribePacket) + catch (OperationCanceledException) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Unsubscribe(unsubscribePacket)); } - - if (packet is MqttPublishPacket publishPacket) + catch (MqttCommunicationException exception) { - return HandleIncomingPublishPacketAsync(publishPacket); + MqttTrace.Warning(nameof(MqttClientSession), exception, "Client '{0}': Communication exception while processing client packets.", _identifier); + Stop(); } - - if (packet is MqttPubRelPacket pubRelPacket) + catch (Exception exception) { - return HandleIncomingPubRelPacketAsync(pubRelPacket); + MqttTrace.Error(nameof(MqttClientSession), exception, "Client '{0}': Unhandled exception while processing client packets.", _identifier); + Stop(); } + } - if (packet is MqttPubRecPacket pubRecPacket) + private async Task ProcessReceivedPacketAsync(MqttBasePacket packet) + { + if (packet is MqttSubscribePacket subscribePacket) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, pubRecPacket.CreateResponse()); + await Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Subscribe(subscribePacket)); + EnqueueRetainedMessages(subscribePacket); } - - if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) + else if (packet is MqttUnsubscribePacket unsubscribePacket) + { + await Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Unsubscribe(unsubscribePacket)); + } + else if (packet is MqttPublishPacket publishPacket) + { + await HandleIncomingPublishPacketAsync(publishPacket); + } + else if (packet is MqttPubRelPacket pubRelPacket) + { + await HandleIncomingPubRelPacketAsync(pubRelPacket); + } + else if (packet is MqttPubRecPacket pubRecPacket) + { + await Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, pubRecPacket.CreateResponse()); + } + else if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) { // Discard message. - return Task.FromResult((object)null); } - - if (packet is MqttPingReqPacket) + else if (packet is MqttPingReqPacket) { - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPingRespPacket()); + await Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPingRespPacket()); } - - if (packet is MqttDisconnectPacket || packet is MqttConnectPacket) + else if (packet is MqttDisconnectPacket || packet is MqttConnectPacket) { _cancellationTokenSource.Cancel(); - return Task.FromResult((object)null); } + else + { + MqttTrace.Warning(nameof(MqttClientSession), "Client '{0}': Received not supported packet ({1}). Closing connection.", _identifier, packet); + _cancellationTokenSource.Cancel(); + } + } - MqttTrace.Warning(nameof(MqttClientSession), "Client '{0}': Received not supported packet ({1}). Closing connection.", _identifier, packet); - _cancellationTokenSource.Cancel(); - - return Task.FromResult((object)null); + private void EnqueueRetainedMessages(MqttSubscribePacket subscribePacket) + { + var retainedMessages = _mqttClientSessionsManager.RetainedMessagesManager.GetMessages(subscribePacket); + foreach (var publishPacket in retainedMessages) + { + EnqueuePublishPacket(publishPacket); + } } - private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) + private async Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) { + if (publishPacket.Retain) + { + await _mqttClientSessionsManager.RetainedMessagesManager.HandleMessageAsync(_identifier, publishPacket); + } + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) { - _publishPacketReceivedCallback(this, publishPacket); - return Task.FromResult(0); + _mqttClientSessionsManager.DispatchPublishPacket(this, publishPacket); + return; } if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) { - _publishPacketReceivedCallback(this, publishPacket); - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + _mqttClientSessionsManager.DispatchPublishPacket(this, publishPacket); + await Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + return; } if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) @@ -171,9 +205,10 @@ namespace MQTTnet.Core.Server _unacknowledgedPublishPackets.Add(publishPacket.PacketIdentifier); } - _publishPacketReceivedCallback(this, publishPacket); + _mqttClientSessionsManager.DispatchPublishPacket(this, publishPacket); - return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + await Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + return; } throw new MqttCommunicationException("Received a not supported QoS level."); diff --git a/MQTTnet.Core/Server/MqttClientSessionsManager.cs b/MQTTnet.Core/Server/MqttClientSessionsManager.cs index 31688c9..71502ac 100644 --- a/MQTTnet.Core/Server/MqttClientSessionsManager.cs +++ b/MQTTnet.Core/Server/MqttClientSessionsManager.cs @@ -14,19 +14,24 @@ namespace MQTTnet.Core.Server { public sealed class MqttClientSessionsManager { - private readonly object _syncRoot = new object(); private readonly Dictionary _clientSessions = new Dictionary(); private readonly MqttServerOptions _options; public MqttClientSessionsManager(MqttServerOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); + RetainedMessagesManager = new MqttClientRetainedMessagesManager(options); } public event EventHandler ApplicationMessageReceived; + public event EventHandler ClientDisconnected; + + public MqttClientRetainedMessagesManager RetainedMessagesManager { get; } + public async Task RunClientSessionAsync(MqttClientConnectedEventArgs eventArgs) { + var clientId = string.Empty; try { if (!(await eventArgs.ClientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false) is MqttConnectPacket connectPacket)) @@ -34,9 +39,11 @@ namespace MQTTnet.Core.Server throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1]."); } + clientId = connectPacket.ClientId; + // Switch to the required protocol version before sending any response. eventArgs.ClientAdapter.PacketSerializer.ProtocolVersion = connectPacket.ProtocolVersion; - + var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { @@ -65,12 +72,13 @@ namespace MQTTnet.Core.Server finally { await eventArgs.ClientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + ClientDisconnected?.Invoke(this, new MqttClientDisconnectedEventArgs(clientId, eventArgs.ClientAdapter)); } } public void Clear() { - lock (_syncRoot) + lock (_clientSessions) { _clientSessions.Clear(); } @@ -78,7 +86,7 @@ namespace MQTTnet.Core.Server public IList GetConnectedClients() { - lock (_syncRoot) + lock (_clientSessions) { return _clientSessions.Where(s => s.Value.IsConnected).Select(s => new ConnectedMqttClient { @@ -88,6 +96,27 @@ namespace MQTTnet.Core.Server } } + public void DispatchPublishPacket(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) + { + try + { + var eventArgs = new MqttApplicationMessageReceivedEventArgs(senderClientSession?.ClientId, publishPacket.ToApplicationMessage()); + ApplicationMessageReceived?.Invoke(this, eventArgs); + } + catch (Exception exception) + { + MqttTrace.Error(nameof(MqttClientSessionsManager), exception, "Error while processing application message"); + } + + lock (_clientSessions) + { + foreach (var clientSession in _clientSessions.Values.ToList()) + { + clientSession.EnqueuePublishPacket(publishPacket); + } + } + } + private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) { if (_options.ConnectionValidator != null) @@ -100,10 +129,9 @@ namespace MQTTnet.Core.Server private GetOrCreateClientSessionResult GetOrCreateClientSession(MqttConnectPacket connectPacket) { - lock (_syncRoot) + lock (_clientSessions) { - MqttClientSession clientSession; - var isSessionPresent = _clientSessions.TryGetValue(connectPacket.ClientId, out clientSession); + var isSessionPresent = _clientSessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) { if (connectPacket.CleanSession) @@ -124,7 +152,7 @@ namespace MQTTnet.Core.Server { isExistingSession = false; - clientSession = new MqttClientSession(connectPacket.ClientId, _options, DispatchPublishPacket); + clientSession = new MqttClientSession(connectPacket.ClientId, _options, this); _clientSessions[connectPacket.ClientId] = clientSession; MqttTrace.Verbose(nameof(MqttClientSessionsManager), "Created a new session for client '{0}'.", connectPacket.ClientId); @@ -133,26 +161,5 @@ namespace MQTTnet.Core.Server return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; } } - - public void DispatchPublishPacket(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) - { - try - { - var eventArgs = new MqttApplicationMessageReceivedEventArgs(senderClientSession?.ClientId, publishPacket.ToApplicationMessage()); - ApplicationMessageReceived?.Invoke(this, eventArgs); - } - catch (Exception exception) - { - MqttTrace.Error(nameof(MqttClientSessionsManager), exception, "Error while processing application message"); - } - - lock (_syncRoot) - { - foreach (var clientSession in _clientSessions.Values.ToList()) - { - clientSession.EnqueuePublishPacket(publishPacket); - } - } - } } } \ No newline at end of file diff --git a/MQTTnet.Core/Server/MqttServer.cs b/MQTTnet.Core/Server/MqttServer.cs index b6f68ee..fd5a4dc 100644 --- a/MQTTnet.Core/Server/MqttServer.cs +++ b/MQTTnet.Core/Server/MqttServer.cs @@ -23,6 +23,7 @@ namespace MQTTnet.Core.Server _clientSessionsManager = new MqttClientSessionsManager(options); _clientSessionsManager.ApplicationMessageReceived += (s, e) => ApplicationMessageReceived?.Invoke(s, e); + _clientSessionsManager.ClientDisconnected += OnClientDisconnected; } public IList GetConnectedClients() @@ -31,7 +32,7 @@ namespace MQTTnet.Core.Server } public event EventHandler ClientConnected; - + public event EventHandler ClientDisconnected; public event EventHandler ApplicationMessageReceived; public void Publish(MqttApplicationMessage applicationMessage) @@ -44,28 +45,29 @@ namespace MQTTnet.Core.Server public void InjectClient(string identifier, IMqttCommunicationAdapter adapter) { if (adapter == null) throw new ArgumentNullException(nameof(adapter)); - if (_cancellationTokenSource == null) throw new InvalidOperationException("The MQTT server is not started."); OnClientConnected(this, new MqttClientConnectedEventArgs(identifier, adapter)); } - public void Start() + public async Task StartAsync() { if (_cancellationTokenSource != null) throw new InvalidOperationException("The MQTT server is already started."); _cancellationTokenSource = new CancellationTokenSource(); + await _clientSessionsManager.RetainedMessagesManager.LoadMessagesAsync(); + foreach (var adapter in _adapters) { adapter.ClientConnected += OnClientConnected; - adapter.Start(_options); + await adapter.StartAsync(_options); } MqttTrace.Information(nameof(MqttServer), "Started."); } - public void Stop() + public async Task StopAsync() { _cancellationTokenSource?.Cancel(false); _cancellationTokenSource?.Dispose(); @@ -74,7 +76,7 @@ namespace MQTTnet.Core.Server foreach (var adapter in _adapters) { adapter.ClientConnected -= OnClientConnected; - adapter.Stop(); + await adapter.StopAsync(); } _clientSessionsManager.Clear(); @@ -89,5 +91,11 @@ namespace MQTTnet.Core.Server Task.Run(() => _clientSessionsManager.RunClientSessionAsync(eventArgs), _cancellationTokenSource.Token); } + + private void OnClientDisconnected(object sender, MqttClientDisconnectedEventArgs eventArgs) + { + MqttTrace.Information(nameof(MqttServer), "Client '{0}': Disconnected.", eventArgs.Identifier); + ClientDisconnected?.Invoke(this, eventArgs); + } } } diff --git a/MQTTnet.Core/Server/MqttServerOptions.cs b/MQTTnet.Core/Server/MqttServerOptions.cs index 9de3627..c3673b9 100644 --- a/MQTTnet.Core/Server/MqttServerOptions.cs +++ b/MQTTnet.Core/Server/MqttServerOptions.cs @@ -15,5 +15,7 @@ namespace MQTTnet.Core.Server public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15); public Func ConnectionValidator { get; set; } + + public IMqttServerStorage Storage { get; set; } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index d3dc7f1..83978f3 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -49,7 +49,7 @@ namespace MQTTnet.Core.Tests public async Task MqttServer_WillMessage() { var s = new MqttServer(new MqttServerOptions(), new List { new TestMqttServerAdapter() }); - s.Start(); + s.StartAsync(); var willMessage = new MqttApplicationMessage("My/last/will", new byte[0], MqttQualityOfServiceLevel.AtMostOnce, false); var c1 = ConnectTestClient("c1", null, s); @@ -63,7 +63,7 @@ namespace MQTTnet.Core.Tests await Task.Delay(1000); - s.Stop(); + s.StopAsync(); Assert.AreEqual(1, receivedMessagesCount); } @@ -72,7 +72,7 @@ namespace MQTTnet.Core.Tests public async Task MqttServer_Unsubscribe() { var s = new MqttServer(new MqttServerOptions(), new List { new TestMqttServerAdapter() }); - s.Start(); + s.StartAsync(); var c1 = ConnectTestClient("c1", null, s); var c2 = ConnectTestClient("c2", null, s); @@ -97,7 +97,7 @@ namespace MQTTnet.Core.Tests await Task.Delay(500); Assert.AreEqual(1, receivedMessagesCount); - s.Stop(); + s.StopAsync(); await Task.Delay(500); Assert.AreEqual(1, receivedMessagesCount); @@ -107,7 +107,7 @@ namespace MQTTnet.Core.Tests public async Task MqttServer_Publish() { var s = new MqttServer(new MqttServerOptions(), new List { new TestMqttServerAdapter() }); - s.Start(); + s.StartAsync(); var c1 = ConnectTestClient("c1", null, s); @@ -120,7 +120,7 @@ namespace MQTTnet.Core.Tests s.Publish(message); await Task.Delay(500); - s.Stop(); + s.StopAsync(); Assert.AreEqual(1, receivedMessagesCount); } @@ -146,7 +146,7 @@ namespace MQTTnet.Core.Tests int expectedReceivedMessagesCount) { var s = new MqttServer(new MqttServerOptions(), new List { new TestMqttServerAdapter() }); - s.Start(); + s.StartAsync(); var c1 = ConnectTestClient("c1", null, s); var c2 = ConnectTestClient("c2", null, s); @@ -162,7 +162,7 @@ namespace MQTTnet.Core.Tests await Task.Delay(500); - s.Stop(); + s.StopAsync(); Assert.AreEqual(expectedReceivedMessagesCount, receivedMessagesCount); } diff --git a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs index 634d5b4..6c4c4a8 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Threading.Tasks; using MQTTnet.Core.Adapter; using MQTTnet.Core.Server; @@ -13,12 +14,14 @@ namespace MQTTnet.Core.Tests ClientConnected?.Invoke(this, eventArgs); } - public void Start(MqttServerOptions options) + public Task StartAsync(MqttServerOptions options) { + return Task.FromResult(0); } - public void Stop() - { + public Task StopAsync() + { + return Task.FromResult(0); } } } \ No newline at end of file diff --git a/Tests/MQTTnet.TestApp.NetCore/Program.cs b/Tests/MQTTnet.TestApp.NetCore/Program.cs index e5470d9..e4ccf95 100644 --- a/Tests/MQTTnet.TestApp.NetCore/Program.cs +++ b/Tests/MQTTnet.TestApp.NetCore/Program.cs @@ -153,12 +153,12 @@ namespace MQTTnet.TestApp.NetCore }; var mqttServer = new MqttServerFactory().CreateMqttServer(options); - mqttServer.Start(); + mqttServer.StartAsync(); Console.WriteLine("Press any key to exit."); Console.ReadLine(); - mqttServer.Stop(); + mqttServer.StopAsync(); } catch (Exception e) { diff --git a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs index 41f24b2..e78a3f9 100644 --- a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs @@ -167,7 +167,7 @@ namespace MQTTnet.TestApp.NetFramework }); } - private static void RunServerAsync() + private static async Task RunServerAsync() { try { @@ -201,12 +201,12 @@ namespace MQTTnet.TestApp.NetFramework stopwatch.Restart(); } }; - mqttServer.Start(); + await mqttServer.StartAsync(); Console.WriteLine("Press any key to exit."); Console.ReadLine(); - mqttServer.Stop(); + await mqttServer.StopAsync(); } catch (Exception e) { diff --git a/Tests/MQTTnet.TestApp.NetFramework/Program.cs b/Tests/MQTTnet.TestApp.NetFramework/Program.cs index 59647b2..b5a7b24 100644 --- a/Tests/MQTTnet.TestApp.NetFramework/Program.cs +++ b/Tests/MQTTnet.TestApp.NetFramework/Program.cs @@ -121,7 +121,7 @@ namespace MQTTnet.TestApp.NetFramework } } - private static void RunServerAsync(string[] arguments) + private static async Task RunServerAsync(string[] arguments) { MqttTrace.TraceMessagePublished += (s, e) => { @@ -151,12 +151,12 @@ namespace MQTTnet.TestApp.NetFramework }; var mqttServer = new MqttServerFactory().CreateMqttServer(options); - mqttServer.Start(); + await mqttServer.StartAsync(); Console.WriteLine("Press any key to exit."); Console.ReadLine(); - mqttServer.Stop(); + await mqttServer.StopAsync(); } catch (Exception e) { @@ -166,4 +166,4 @@ namespace MQTTnet.TestApp.NetFramework Console.ReadLine(); } } -} +} \ No newline at end of file