From 8de87c8772a94c249b1a90fcd04ff0711a73a5e3 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Thu, 16 Nov 2017 00:28:56 +0100 Subject: [PATCH] Wrap COM exceptions, close client connection when server stops, process retained messages from server --- Build/MQTTnet.nuspec | 7 +- .../Implementations/MqttWebSocketChannel.cs | 15 +- .../Implementations/WebSocketStream.cs | 2 +- Frameworks/MQTTnet.NetStandard/MqttFactory.cs | 4 +- .../MqttChannelCommunicationAdapter.cs | 225 ++++++++---------- MQTTnet.Core/Server/IMqttServer.cs | 3 +- .../Server/MqttClientPendingMessagesQueue.cs | 2 +- MQTTnet.Core/Server/MqttClientSession.cs | 53 ++--- .../Server/MqttClientSessionsManager.cs | 85 +++++-- MQTTnet.Core/Server/MqttServer.cs | 30 +-- 10 files changed, 201 insertions(+), 225 deletions(-) diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 8bdd4a1..c1cd371 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -10,8 +10,11 @@ https://raw.githubusercontent.com/chkr1011/MQTTnet/master/Images/Logo_128x128.png false MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker). - * [Client] Fixed WebSocket sub protocol negotiation (Thanks to @JanEggers) - + * [Core] Fixed library reference issues for .NET 4.6 (Thanks to @JanEggers). +* [Core] Several COM exceptions are now wrapped properly resulting in less warnings in the trace. +* [Client] Fixed WebSocket sub protocol negotiation for ASP.NET Core 2 servers (Thanks to @JanEggers). +* [Server] Client connections are now closed when the server is stopped (Thanks to @zhudanfei). +* [Server] Published messages from the server are now retained (if set) (Thanks to @ChristianRiedl). BREAKING CHANGE! Copyright Christian Kratky 2016-2017 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs index 1314a1c..c7cd3d0 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs @@ -19,9 +19,8 @@ namespace MQTTnet.Implementations _options = options ?? throw new ArgumentNullException(nameof(options)); } - public Stream SendStream => RawReceiveStream; - public Stream ReceiveStream => RawReceiveStream; - public Stream RawReceiveStream { get; private set; } + public Stream SendStream { get; private set; } + public Stream ReceiveStream { get; private set; } public async Task ConnectAsync() { @@ -32,7 +31,7 @@ namespace MQTTnet.Implementations } _webSocket = new ClientWebSocket(); - + if (_options.RequestHeaders != null) { foreach (var requestHeader in _options.RequestHeaders) @@ -64,13 +63,13 @@ namespace MQTTnet.Implementations } await _webSocket.ConnectAsync(new Uri(uri), CancellationToken.None).ConfigureAwait(false); - RawReceiveStream = new WebSocketStream(_webSocket); + + SendStream = new WebSocketStream(_webSocket); + ReceiveStream = SendStream; } public async Task DisconnectAsync() { - RawReceiveStream = null; - if (_webSocket == null) { return; @@ -80,6 +79,8 @@ namespace MQTTnet.Implementations { await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); } + + _webSocket = null; } public void Dispose() diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs index c27e480..fb2d812 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs @@ -43,7 +43,7 @@ namespace MQTTnet.Implementations var response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, currentOffset, count), cancellationToken).ConfigureAwait(false); currentOffset += response.Count; count -= response.Count; - + if (response.MessageType == WebSocketMessageType.Close) { await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); diff --git a/Frameworks/MQTTnet.NetStandard/MqttFactory.cs b/Frameworks/MQTTnet.NetStandard/MqttFactory.cs index 1729350..c8b7968 100644 --- a/Frameworks/MQTTnet.NetStandard/MqttFactory.cs +++ b/Frameworks/MQTTnet.NetStandard/MqttFactory.cs @@ -85,9 +85,7 @@ namespace MQTTnet clientSessionsManager, _serviceProvider.GetRequiredService(), _serviceProvider.GetRequiredService>(), - _serviceProvider.GetRequiredService>(), - _serviceProvider.GetRequiredService() - ); + _serviceProvider.GetRequiredService>()); } public IMqttClient CreateMqttClient() diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index 3557302..c859ec0 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Channel; @@ -14,10 +15,12 @@ namespace MQTTnet.Core.Adapter { public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter { + private const uint ErrorOperationAborted = 0x800703E3; + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly ILogger _logger; private readonly IMqttCommunicationChannel _channel; - + public MqttChannelCommunicationAdapter(IMqttCommunicationChannel channel, IMqttPacketSerializer serializer, ILogger logger) { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); @@ -29,166 +32,87 @@ namespace MQTTnet.Core.Adapter public async Task ConnectAsync(TimeSpan timeout) { - try - { - await _channel.ConnectAsync().TimeoutAfter(timeout).ConfigureAwait(false); - } - catch (TaskCanceledException) - { - throw; - } - catch (OperationCanceledException) - { - throw; - } - catch (MqttCommunicationTimedOutException) - { - throw; - } - catch (MqttCommunicationException) - { - throw; - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } + await ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync().TimeoutAfter(timeout)); } public async Task DisconnectAsync(TimeSpan timeout) { - try - { - await _channel.DisconnectAsync().TimeoutAfter(timeout).ConfigureAwait(false); - } - catch (TaskCanceledException) - { - throw; - } - catch (OperationCanceledException) - { - throw; - } - catch (MqttCommunicationTimedOutException) - { - throw; - } - catch (MqttCommunicationException) - { - throw; - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } + await ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); } public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) { - try + await ExecuteAndWrapExceptionAsync(async () => { await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - - foreach (var packet in packets) + try { - if (packet == null) + foreach (var packet in packets) { - continue; - } + if (packet == null) + { + continue; + } - _logger.LogInformation("TX >>> {0} [Timeout={1}]", packet, timeout); + _logger.LogInformation("TX >>> {0} [Timeout={1}]", packet, timeout); - var writeBuffer = PacketSerializer.Serialize(packet); - await _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false); - } + var writeBuffer = PacketSerializer.Serialize(packet); + await _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false); + } - if (timeout > TimeSpan.Zero) - { - await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); + if (timeout > TimeSpan.Zero) + { + await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); + } + else + { + await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); + } } - else + finally { - await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); + _semaphore.Release(); } - } - catch (TaskCanceledException) - { - throw; - } - catch (OperationCanceledException) - { - throw; - } - catch (MqttCommunicationTimedOutException) - { - throw; - } - catch (MqttCommunicationException) - { - throw; - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } - finally - { - _semaphore.Release(); - } + }); } public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) { - ReceivedMqttPacket receivedMqttPacket = null; - try + MqttBasePacket packet = null; + await ExecuteAndWrapExceptionAsync(async () => { - if (timeout > TimeSpan.Zero) - { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); - } - else + ReceivedMqttPacket receivedMqttPacket = null; + try { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); - } + if (timeout > TimeSpan.Zero) + { + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); + } + else + { + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); + } - if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } + if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); + } + + packet = PacketSerializer.Deserialize(receivedMqttPacket); + if (packet == null) + { + throw new MqttProtocolViolationException("Received malformed packet."); + } - var packet = PacketSerializer.Deserialize(receivedMqttPacket); - if (packet == null) + _logger.LogInformation("RX <<< {0}", packet); + } + finally { - throw new MqttProtocolViolationException("Received malformed packet."); + receivedMqttPacket?.Dispose(); } + }); - _logger.LogInformation("RX <<< {0}", packet); - return packet; - } - catch (TaskCanceledException) - { - throw; - } - catch (OperationCanceledException) - { - throw; - } - catch (MqttCommunicationTimedOutException) - { - throw; - } - catch (MqttCommunicationException) - { - throw; - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } - finally - { - receivedMqttPacket?.Dispose(); - } + return packet; } private static async Task ReceiveAsync(Stream stream, CancellationToken cancellationToken) @@ -215,5 +139,42 @@ namespace MQTTnet.Core.Adapter return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length)); } + + private static async Task ExecuteAndWrapExceptionAsync(Func action) + { + try + { + await action().ConfigureAwait(false); + } + catch (TaskCanceledException) + { + throw; + } + catch (OperationCanceledException) + { + throw; + } + catch (MqttCommunicationTimedOutException) + { + throw; + } + catch (MqttCommunicationException) + { + throw; + } + catch (COMException comException) + { + if ((uint)comException.HResult == ErrorOperationAborted) + { + throw new OperationCanceledException(); + } + + throw new MqttCommunicationException(comException); + } + catch (Exception exception) + { + throw new MqttCommunicationException(exception); + } + } } } diff --git a/MQTTnet.Core/Server/IMqttServer.cs b/MQTTnet.Core/Server/IMqttServer.cs index 13c614f..8d86077 100644 --- a/MQTTnet.Core/Server/IMqttServer.cs +++ b/MQTTnet.Core/Server/IMqttServer.cs @@ -10,8 +10,7 @@ namespace MQTTnet.Core.Server event EventHandler ClientDisconnected; event EventHandler Started; - IList GetConnectedClients(); - void Publish(IEnumerable applicationMessages); + Task> GetConnectedClientsAsync(); Task StartAsync(); Task StopAsync(); diff --git a/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs b/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs index 1e0fc76..16c9094 100644 --- a/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs +++ b/MQTTnet.Core/Server/MqttClientPendingMessagesQueue.cs @@ -96,7 +96,7 @@ namespace MQTTnet.Core.Server _pendingPublishPackets.Add(packet, CancellationToken.None); } - _session.Stop(); + await _session.StopAsync(); } } } diff --git a/MQTTnet.Core/Server/MqttClientSession.cs b/MQTTnet.Core/Server/MqttClientSession.cs index 93b422f..42451dd 100644 --- a/MQTTnet.Core/Server/MqttClientSession.cs +++ b/MQTTnet.Core/Server/MqttClientSession.cs @@ -17,7 +17,6 @@ namespace MQTTnet.Core.Server { private readonly HashSet _unacknowledgedPublishPackets = new HashSet(); - private readonly IMqttClientRetainedMessageManager _clientRetainedMessageManager; private readonly MqttClientSubscriptionsManager _subscriptionsManager; private readonly MqttClientSessionsManager _sessionsManager; private readonly MqttClientPendingMessagesQueue _pendingMessagesQueue; @@ -34,10 +33,8 @@ namespace MQTTnet.Core.Server MqttClientSessionsManager sessionsManager, MqttClientSubscriptionsManager subscriptionsManager, ILogger logger, - ILogger messageQueueLogger, - IMqttClientRetainedMessageManager clientRetainedMessageManager) + ILogger messageQueueLogger) { - _clientRetainedMessageManager = clientRetainedMessageManager ?? throw new ArgumentNullException(nameof(clientRetainedMessageManager)); _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); _subscriptionsManager = subscriptionsManager ?? throw new ArgumentNullException(nameof(subscriptionsManager)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); @@ -82,7 +79,7 @@ namespace MQTTnet.Core.Server } } - public void Stop() + public async Task StopAsync() { try { @@ -90,17 +87,21 @@ namespace MQTTnet.Core.Server _cancellationTokenSource?.Dispose(); _cancellationTokenSource = null; - _adapter = null; + if (_adapter != null) + { + await _adapter.DisconnectAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + _adapter = null; + } - _logger.LogInformation("Client '{0}': Disconnected.", ClientId); + _logger.LogInformation("Client '{0}': Session stopped.", ClientId); } finally { - var willMessage = _willMessage; + var willMessage = _willMessage; if (willMessage != null) { _willMessage = null; //clear willmessage so it is send just once - _sessionsManager.DispatchApplicationMessage(this, willMessage); + await _sessionsManager.DispatchApplicationMessageAsync(this, willMessage); } } } @@ -133,12 +134,12 @@ namespace MQTTnet.Core.Server catch (MqttCommunicationException exception) { _logger.LogWarning(new EventId(), exception, "Client '{0}': Communication exception while processing client packets.", ClientId); - Stop(); + await StopAsync(); } catch (Exception exception) { _logger.LogError(new EventId(), exception, "Client '{0}': Unhandled exception while processing client packets.", ClientId); - Stop(); + await StopAsync(); } } @@ -182,14 +183,11 @@ namespace MQTTnet.Core.Server if (packet is MqttDisconnectPacket || packet is MqttConnectPacket) { - Stop(); - return Task.FromResult(0); + return StopAsync(); } _logger.LogWarning("Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); - Stop(); - - return Task.FromResult(0); + return StopAsync(); } private async Task HandleIncomingSubscribePacketAsync(IMqttCommunicationAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) @@ -202,13 +200,13 @@ namespace MQTTnet.Core.Server if (subscribeResult.CloseConnection) { await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttDisconnectPacket()).ConfigureAwait(false); - Stop(); + await StopAsync(); } } private async Task EnqueueSubscribedRetainedMessagesAsync(MqttSubscribePacket subscribePacket) { - var retainedMessages = await _clientRetainedMessageManager.GetSubscribedMessagesAsync(subscribePacket).ConfigureAwait(false); + var retainedMessages = await _sessionsManager.GetRetainedMessagesAsync(subscribePacket).ConfigureAwait(false); foreach (var publishPacket in retainedMessages) { EnqueuePublishPacket(publishPacket.ToPublishPacket()); @@ -219,29 +217,16 @@ namespace MQTTnet.Core.Server { var applicationMessage = publishPacket.ToApplicationMessage(); - var interceptorContext = new MqttApplicationMessageInterceptorContext - { - ApplicationMessage = applicationMessage - }; - - _options.ApplicationMessageInterceptor?.Invoke(interceptorContext); - applicationMessage = interceptorContext.ApplicationMessage; - - if (applicationMessage.Retain) - { - await _clientRetainedMessageManager.HandleMessageAsync(ClientId, applicationMessage).ConfigureAwait(false); - } - switch (applicationMessage.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: { - _sessionsManager.DispatchApplicationMessage(this, applicationMessage); + await _sessionsManager.DispatchApplicationMessageAsync(this, applicationMessage); return; } case MqttQualityOfServiceLevel.AtLeastOnce: { - _sessionsManager.DispatchApplicationMessage(this, applicationMessage); + await _sessionsManager.DispatchApplicationMessageAsync(this, applicationMessage); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); @@ -256,7 +241,7 @@ namespace MQTTnet.Core.Server _unacknowledgedPublishPackets.Add(publishPacket.PacketIdentifier); } - _sessionsManager.DispatchApplicationMessage(this, applicationMessage); + await _sessionsManager.DispatchApplicationMessageAsync(this, applicationMessage); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }).ConfigureAwait(false); diff --git a/MQTTnet.Core/Server/MqttClientSessionsManager.cs b/MQTTnet.Core/Server/MqttClientSessionsManager.cs index 48fd33e..80722ab 100644 --- a/MQTTnet.Core/Server/MqttClientSessionsManager.cs +++ b/MQTTnet.Core/Server/MqttClientSessionsManager.cs @@ -16,19 +16,24 @@ namespace MQTTnet.Core.Server { public sealed class MqttClientSessionsManager { - private readonly Dictionary _clientSessions = new Dictionary(); + private readonly Dictionary _sessions = new Dictionary(); + private readonly SemaphoreSlim _sessionsSemaphore = new SemaphoreSlim(1, 1); + + private readonly MqttServerOptions _options; private readonly ILogger _logger; private readonly IMqttClientSesssionFactory _clientSesssionFactory; - private readonly MqttServerOptions _options; + private readonly IMqttClientRetainedMessageManager _clientRetainedMessageManager; public MqttClientSessionsManager( - IOptions options, + IOptions options, ILogger logger, - IMqttClientSesssionFactory clientSesssionFactory) + IMqttClientSesssionFactory clientSesssionFactory, + IMqttClientRetainedMessageManager clientRetainedMessageManager) { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _options = options.Value ?? throw new ArgumentNullException(nameof(options)); _clientSesssionFactory = clientSesssionFactory ?? throw new ArgumentNullException(nameof(clientSesssionFactory)); + _clientRetainedMessageManager = clientRetainedMessageManager ?? throw new ArgumentNullException(nameof(clientRetainedMessageManager)); } public event EventHandler ClientConnected; @@ -61,7 +66,7 @@ namespace MQTTnet.Core.Server return; } - var clientSession = GetOrCreateClientSession(connectPacket); + var clientSession = await GetOrCreateClientSessionAsync(connectPacket); await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttConnAckPacket { @@ -92,7 +97,7 @@ namespace MQTTnet.Core.Server } catch (Exception) { - //ignored + // ignored } ClientDisconnected?.Invoke(this, new MqttClientDisconnectedEventArgs(new ConnectedMqttClient @@ -103,30 +108,58 @@ namespace MQTTnet.Core.Server } } - public void Clear() + public async Task StopAsync() { - lock (_clientSessions) + await _sessionsSemaphore.WaitAsync().ConfigureAwait(false); + try + { + foreach (var session in _sessions) + { + await session.Value.StopAsync(); + } + + _sessions.Clear(); + } + finally { - _clientSessions.Clear(); + _sessionsSemaphore.Release(); } } - public IList GetConnectedClients() + public async Task> GetConnectedClientsAsync() { - lock (_clientSessions) + await _sessionsSemaphore.WaitAsync().ConfigureAwait(false); + try { - return _clientSessions.Where(s => s.Value.IsConnected).Select(s => new ConnectedMqttClient + return _sessions.Where(s => s.Value.IsConnected).Select(s => new ConnectedMqttClient { ClientId = s.Value.ClientId, ProtocolVersion = s.Value.ProtocolVersion ?? MqttProtocolVersion.V311 }).ToList(); } + finally + { + _sessionsSemaphore.Release(); + } } - public void DispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) + public async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) { try { + var interceptorContext = new MqttApplicationMessageInterceptorContext + { + ApplicationMessage = applicationMessage + }; + + _options.ApplicationMessageInterceptor?.Invoke(interceptorContext); + applicationMessage = interceptorContext.ApplicationMessage; + + if (applicationMessage.Retain) + { + await _clientRetainedMessageManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false); + } + var eventArgs = new MqttApplicationMessageReceivedEventArgs(senderClientSession?.ClientId, applicationMessage); ApplicationMessageReceived?.Invoke(this, eventArgs); } @@ -135,15 +168,20 @@ namespace MQTTnet.Core.Server _logger.LogError(new EventId(), exception, "Error while processing application message"); } - lock (_clientSessions) + lock (_sessions) { - foreach (var clientSession in _clientSessions.Values.ToList()) + foreach (var clientSession in _sessions.Values.ToList()) { clientSession.EnqueuePublishPacket(applicationMessage.ToPublishPacket()); } } } + public Task> GetRetainedMessagesAsync(MqttSubscribePacket subscribePacket) + { + return _clientRetainedMessageManager.GetSubscribedMessagesAsync(subscribePacket); + } + private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) { if (_options.ConnectionValidator != null) @@ -154,17 +192,18 @@ namespace MQTTnet.Core.Server return MqttConnectReturnCode.ConnectionAccepted; } - private GetOrCreateClientSessionResult GetOrCreateClientSession(MqttConnectPacket connectPacket) + private async Task GetOrCreateClientSessionAsync(MqttConnectPacket connectPacket) { - lock (_clientSessions) + await _sessionsSemaphore.WaitAsync().ConfigureAwait(false); + try { - var isSessionPresent = _clientSessions.TryGetValue(connectPacket.ClientId, out var clientSession); + var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) { if (connectPacket.CleanSession) { - _clientSessions.Remove(connectPacket.ClientId); - clientSession.Stop(); + _sessions.Remove(connectPacket.ClientId); + await clientSession.StopAsync(); clientSession = null; _logger.LogTrace("Stopped existing session of client '{0}'.", connectPacket.ClientId); @@ -181,13 +220,17 @@ namespace MQTTnet.Core.Server isExistingSession = false; clientSession = _clientSesssionFactory.CreateClientSession(connectPacket.ClientId, this); - _clientSessions[connectPacket.ClientId] = clientSession; + _sessions[connectPacket.ClientId] = clientSession; _logger.LogTrace("Created a new session for client '{0}'.", connectPacket.ClientId); } return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; } + finally + { + _sessionsSemaphore.Release(); + } } } } \ No newline at end of file diff --git a/MQTTnet.Core/Server/MqttServer.cs b/MQTTnet.Core/Server/MqttServer.cs index bd06a3a..c671907 100644 --- a/MQTTnet.Core/Server/MqttServer.cs +++ b/MQTTnet.Core/Server/MqttServer.cs @@ -20,12 +20,11 @@ namespace MQTTnet.Core.Server private CancellationTokenSource _cancellationTokenSource; public MqttServer( - IOptions options, + IOptions options, IEnumerable adapters, ILogger logger, MqttClientSessionsManager clientSessionsManager, - IMqttClientRetainedMessageManager clientRetainedMessageManager - ) + IMqttClientRetainedMessageManager clientRetainedMessageManager) { _options = options.Value ?? throw new ArgumentNullException(nameof(options)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); @@ -35,7 +34,7 @@ namespace MQTTnet.Core.Server if (adapters == null) { throw new ArgumentNullException(nameof(adapters)); - } + } _adapters = adapters.ToList(); @@ -44,9 +43,9 @@ namespace MQTTnet.Core.Server _clientSessionsManager.ClientDisconnected += OnClientDisconnected; } - public IList GetConnectedClients() + public Task> GetConnectedClientsAsync() { - return _clientSessionsManager.GetConnectedClients(); + return _clientSessionsManager.GetConnectedClientsAsync(); } public event EventHandler Started; @@ -54,7 +53,7 @@ namespace MQTTnet.Core.Server public event EventHandler ClientDisconnected; public event EventHandler ApplicationMessageReceived; - public void Publish(IEnumerable applicationMessages) + public async Task PublishAsync(IEnumerable applicationMessages) { if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); @@ -65,23 +64,10 @@ namespace MQTTnet.Core.Server foreach (var applicationMessage in applicationMessages) { - var interceptorContext = new MqttApplicationMessageInterceptorContext - { - ApplicationMessage = applicationMessage - }; - - _options.ApplicationMessageInterceptor?.Invoke(interceptorContext); - - _clientSessionsManager.DispatchApplicationMessage(null, interceptorContext.ApplicationMessage); + await _clientSessionsManager.DispatchApplicationMessageAsync(null, applicationMessage); } } - public Task PublishAsync(IEnumerable applicationMessages) - { - Publish(applicationMessages); - return Task.FromResult(0); - } - public async Task StartAsync() { if (_cancellationTokenSource != null) throw new InvalidOperationException("The MQTT server is already started."); @@ -113,7 +99,7 @@ namespace MQTTnet.Core.Server await adapter.StopAsync(); } - _clientSessionsManager.Clear(); + await _clientSessionsManager.StopAsync(); _logger.LogInformation("Stopped."); }