diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index fd4b820..be29f5d 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -66,24 +66,26 @@ namespace MQTTnet.Client _packetIdentifierProvider.Reset(); _packetDispatcher.Reset(); - _cancellationTokenSource = new CancellationTokenSource(); + var cancellationTokenSource = new CancellationTokenSource(); + _cancellationTokenSource = cancellationTokenSource; + _disconnectGate = 0; _adapter = _adapterFactory.CreateClientAdapter(options, _logger); _logger.Verbose($"Trying to connect with server ({Options.ChannelOptions})."); - await _adapter.ConnectAsync(Options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false); + await _adapter.ConnectAsync(Options.CommunicationTimeout, cancellationTokenSource.Token).ConfigureAwait(false); _logger.Verbose("Connection with server established."); - StartReceivingPackets(_cancellationTokenSource.Token); + StartReceivingPackets(cancellationTokenSource.Token); - var connectResult = await AuthenticateAsync(options.WillMessage, _cancellationTokenSource.Token).ConfigureAwait(false); + var connectResult = await AuthenticateAsync(options.WillMessage, cancellationTokenSource.Token).ConfigureAwait(false); _logger.Verbose("MQTT connection with server established."); _sendTracker.Restart(); if (Options.KeepAlivePeriod != TimeSpan.Zero) { - StartSendingKeepAliveMessages(_cancellationTokenSource.Token); + StartSendingKeepAliveMessages(cancellationTokenSource.Token); } IsConnected = true; @@ -112,7 +114,7 @@ namespace MQTTnet.Client { _cleanDisconnectInitiated = true; - if (IsConnected && !_cancellationTokenSource.IsCancellationRequested) + if (IsConnected && _cancellationTokenSource?.IsCancellationRequested == false) { var disconnectPacket = CreateDisconnectPacket(options); await SendAsync(disconnectPacket, _cancellationTokenSource.Token).ConfigureAwait(false); diff --git a/Source/MQTTnet/Client/MqttClientExtensions.cs b/Source/MQTTnet/Client/MqttClientExtensions.cs index 00482ea..b399ef6 100644 --- a/Source/MQTTnet/Client/MqttClientExtensions.cs +++ b/Source/MQTTnet/Client/MqttClientExtensions.cs @@ -1,5 +1,4 @@ using System; -using System.Linq; using System.Threading.Tasks; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; @@ -21,7 +20,7 @@ namespace MQTTnet.Client if (client == null) throw new ArgumentNullException(nameof(client)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return client.SubscribeAsync(topicFilters.ToList()); + return client.SubscribeAsync(topicFilters); } public static Task SubscribeAsync(this IMqttClient client, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel) @@ -45,7 +44,7 @@ namespace MQTTnet.Client if (client == null) throw new ArgumentNullException(nameof(client)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return client.UnsubscribeAsync(topicFilters.ToList()); + return client.UnsubscribeAsync(topicFilters); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index e389ef0..802463a 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -93,7 +93,7 @@ namespace MQTTnet.Implementations using (cancellationToken.Register(() => _socket.Dispose())) { await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); - await _stream.FlushAsync(cancellationToken); + //await _stream.FlushAsync(cancellationToken); } } diff --git a/Source/MQTTnet/Server/IMqttClientSessionStatus.cs b/Source/MQTTnet/Server/IMqttClientSessionStatus.cs index 38b52e2..d46db69 100644 --- a/Source/MQTTnet/Server/IMqttClientSessionStatus.cs +++ b/Source/MQTTnet/Server/IMqttClientSessionStatus.cs @@ -18,7 +18,11 @@ namespace MQTTnet.Server TimeSpan LastNonKeepAlivePacketReceived { get; } - int PendingApplicationMessagesCount { get; } + long PendingApplicationMessagesCount { get; } + + long ReceivedApplicationMessagesCount { get; } + + long SentApplicationMessagesCount { get; } Task DisconnectAsync(); diff --git a/Source/MQTTnet/Server/IMqttServer.cs b/Source/MQTTnet/Server/IMqttServer.cs index e42c903..cccda2b 100644 --- a/Source/MQTTnet/Server/IMqttServer.cs +++ b/Source/MQTTnet/Server/IMqttServer.cs @@ -21,8 +21,8 @@ namespace MQTTnet.Server IList GetRetainedMessages(); Task ClearRetainedMessagesAsync(); - Task SubscribeAsync(string clientId, IList topicFilters); - Task UnsubscribeAsync(string clientId, IList topicFilters); + Task SubscribeAsync(string clientId, IEnumerable topicFilters); + Task UnsubscribeAsync(string clientId, IEnumerable topicFilters); Task StartAsync(IMqttServerOptions options); Task StopAsync(); diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 0457fef..688d094 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -18,7 +19,7 @@ namespace MQTTnet.Server private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly MqttServerEventDispatcher _eventDispatcher; private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; - private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue; + private readonly MqttClientSessionPendingMessagesQueue _pendingMessagesQueue; private readonly MqttClientSubscriptionsManager _subscriptionsManager; private readonly MqttClientSessionsManager _sessionsManager; @@ -31,6 +32,9 @@ namespace MQTTnet.Server private Task _workerTask; private IMqttChannelAdapter _channelAdapter; + private long _receivedMessagesCount; + private bool _isCleanSession = true; + public MqttClientSession( string clientId, IMqttServerOptions options, @@ -52,7 +56,7 @@ namespace MQTTnet.Server _keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, eventDispatcher); - _pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger); + _pendingMessagesQueue = new MqttClientSessionPendingMessagesQueue(_options, this, _logger); } public string ClientId { get; } @@ -63,28 +67,38 @@ namespace MQTTnet.Server status.IsConnected = _cancellationTokenSource != null; status.Endpoint = _channelAdapter?.Endpoint; status.ProtocolVersion = _channelAdapter?.PacketFormatterAdapter?.ProtocolVersion; - status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count; + status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count; + status.ReceivedApplicationMessagesCount = _pendingMessagesQueue.SentMessagesCount; + status.SentApplicationMessagesCount = Interlocked.Read(ref _receivedMessagesCount); status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived; status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; } - public Task StopAsync(MqttClientDisconnectType type) + public async Task StopAsync(MqttClientDisconnectType type) { - return StopAsync(type, false); + StopInternal(type); + + var task = _workerTask; + if (task != null && !task.IsCompleted) + { + await task.ConfigureAwait(false); + } } - public async Task SubscribeAsync(IList topicFilters) + public async Task SubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + var topicFiltersCollection = topicFilters.ToList(); + var packet = new MqttSubscribePacket(); - packet.TopicFilters.AddRange(topicFilters); + packet.TopicFilters.AddRange(topicFiltersCollection); await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false); - await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false); + await EnqueueSubscribedRetainedMessagesAsync(topicFiltersCollection).ConfigureAwait(false); } - public Task UnsubscribeAsync(IList topicFilters) + public Task UnsubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -98,12 +112,12 @@ namespace MQTTnet.Server public void ClearPendingApplicationMessages() { - _pendingPacketsQueue.Clear(); + _pendingMessagesQueue.Clear(); } public void Dispose() { - _pendingPacketsQueue?.Dispose(); + _pendingMessagesQueue?.Dispose(); _cancellationTokenSource?.Cancel(); _cancellationTokenSource?.Dispose(); @@ -161,7 +175,7 @@ namespace MQTTnet.Server publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; } - _pendingPacketsQueue.Enqueue(publishPacket); + _pendingMessagesQueue.Enqueue(publishPacket); } private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) @@ -170,26 +184,41 @@ namespace MQTTnet.Server try { - channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted; - channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted; + _logger.Info("Client '{0}': Connected.", ClientId); + _eventDispatcher.OnClientConnected(ClientId); + + _channelAdapter = channelAdapter; + + _channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted; + _channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted; - _cancellationTokenSource = new CancellationTokenSource(); + var cancellationTokenSource = new CancellationTokenSource(); + _cancellationTokenSource = cancellationTokenSource; _wasCleanDisconnect = false; _willMessage = connectPacket.WillMessage; - _pendingPacketsQueue.Start(channelAdapter, _cancellationTokenSource.Token); - _keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token); + _pendingMessagesQueue.Start(channelAdapter, cancellationTokenSource.Token); + _keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, cancellationTokenSource.Token); - _channelAdapter = channelAdapter; + await channelAdapter.SendPacketAsync( + new MqttConnAckPacket + { + ReturnCode = MqttConnectReturnCode.ConnectionAccepted, + ReasonCode = MqttConnectReasonCode.Success, + IsSessionPresent = _isCleanSession + }, + cancellationTokenSource.Token).ConfigureAwait(false); - while (!_cancellationTokenSource.IsCancellationRequested) + _isCleanSession = false; + + while (!cancellationTokenSource.IsCancellationRequested) { - var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false); + var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, cancellationTokenSource.Token).ConfigureAwait(false); if (packet != null) { _keepAliveMonitor.PacketReceived(packet); - await ProcessReceivedPacketAsync(channelAdapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false); + await ProcessReceivedPacketAsync(channelAdapter, packet, cancellationTokenSource.Token).ConfigureAwait(false); } } } @@ -203,6 +232,9 @@ namespace MQTTnet.Server if (exception is MqttCommunicationClosedGracefullyException) { _logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId); + + StopInternal(MqttClientDisconnectType.Clean); + return; } else { @@ -214,69 +246,50 @@ namespace MQTTnet.Server _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); } - await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false); + StopInternal(MqttClientDisconnectType.NotClean); } finally { - channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; - channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; - - _cancellationTokenSource?.Cancel(false); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; - - _workerTask = null; - } - } - - private async Task StopAsync(MqttClientDisconnectType type, bool isInsideSession) - { - try - { - var cts = _cancellationTokenSource; - if (cts == null || cts.IsCancellationRequested) - { - return; - } - - _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; - - _cancellationTokenSource?.Cancel(false); - if (_willMessage != null && !_wasCleanDisconnect) { _sessionsManager.EnqueueApplicationMessage(this, _willMessage); } _willMessage = null; - - if (!isInsideSession) - { - if (_workerTask != null) - { - await _workerTask.ConfigureAwait(false); - } - } - await Task.FromResult(0); - } - finally - { + _channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; + _channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; + _channelAdapter = null; + _logger.Info("Client '{0}': Session stopped.", ClientId); _eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect); + + _workerTask = null; } } + + private void StopInternal(MqttClientDisconnectType type) + { + var cts = _cancellationTokenSource; + if (cts == null || cts.IsCancellationRequested) + { + return; + } + + _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; + _cancellationTokenSource?.Cancel(false); + } - private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) + private Task ProcessReceivedPacketAsync(IMqttChannelAdapter channelAdapter, MqttBasePacket packet, CancellationToken cancellationToken) { if (packet is MqttPublishPacket publishPacket) { - return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken); + return HandleIncomingPublishPacketAsync(channelAdapter, publishPacket, cancellationToken); } if (packet is MqttPingReqPacket) { - return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); + return channelAdapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); } if (packet is MqttPubRelPacket pubRelPacket) @@ -287,7 +300,7 @@ namespace MQTTnet.Server ReasonCode = MqttPubCompReasonCode.Success }; - return adapter.SendPacketAsync(responsePacket, cancellationToken); + return channelAdapter.SendPacketAsync(responsePacket, cancellationToken); } if (packet is MqttPubRecPacket pubRecPacket) @@ -298,7 +311,7 @@ namespace MQTTnet.Server ReasonCode = MqttPubRelReasonCode.Success }; - return adapter.SendPacketAsync(responsePacket, cancellationToken); + return channelAdapter.SendPacketAsync(responsePacket, cancellationToken); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -308,27 +321,24 @@ namespace MQTTnet.Server if (packet is MqttSubscribePacket subscribePacket) { - return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken); + return HandleIncomingSubscribePacketAsync(channelAdapter, subscribePacket, cancellationToken); } if (packet is MqttUnsubscribePacket unsubscribePacket) { - return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken); + return HandleIncomingUnsubscribePacketAsync(channelAdapter, unsubscribePacket, cancellationToken); } if (packet is MqttDisconnectPacket) { - return StopAsync(MqttClientDisconnectType.Clean, true); - } - - if (packet is MqttConnectPacket) - { - return StopAsync(MqttClientDisconnectType.NotClean, true); + StopInternal(MqttClientDisconnectType.Clean); + return Task.FromResult(0); } - _logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); + _logger.Warning(null, "Client '{0}': Received invalid packet ({1}). Closing connection.", ClientId, packet); - return StopAsync(MqttClientDisconnectType.NotClean, true); + StopInternal(MqttClientDisconnectType.NotClean); + return Task.FromResult(0); } private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) @@ -347,7 +357,8 @@ namespace MQTTnet.Server if (subscribeResult.CloseConnection) { - await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false); + StopInternal(MqttClientDisconnectType.NotClean); + return; } await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); @@ -361,12 +372,13 @@ namespace MQTTnet.Server private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { + Interlocked.Increment(ref _receivedMessagesCount); + switch (publishPacket.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: { - HandleIncomingPublishPacketWithQoS0(publishPacket); - return Task.FromResult(0); + return HandleIncomingPublishPacketWithQoS0Async(publishPacket); } case MqttQualityOfServiceLevel.AtLeastOnce: { @@ -383,11 +395,13 @@ namespace MQTTnet.Server } } - private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket) + private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) { _sessionsManager.EnqueueApplicationMessage( this, _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); + + return Task.FromResult(0); } private Task HandleIncomingPublishPacketWithQoS1Async( diff --git a/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs b/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs similarity index 88% rename from Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs rename to Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs index eddd788..60d5869 100644 --- a/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs +++ b/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs @@ -11,22 +11,24 @@ using MQTTnet.Protocol; namespace MQTTnet.Server { - public class MqttClientPendingPacketsQueue : IDisposable + public class MqttClientSessionPendingMessagesQueue : IDisposable { private readonly Queue _queue = new Queue(); - private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); + private readonly AsyncAutoResetEvent _queueLock = new AsyncAutoResetEvent(); private readonly IMqttServerOptions _options; private readonly MqttClientSession _clientSession; private readonly IMqttNetChildLogger _logger; - public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) + private long _sentPacketsCount; + + public MqttClientSessionPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _options = options ?? throw new ArgumentNullException(nameof(options)); _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); - _logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue)); + _logger = logger.CreateChildLogger(nameof(MqttClientSessionPendingMessagesQueue)); } public int Count @@ -40,6 +42,8 @@ namespace MQTTnet.Server } } + public long SentMessagesCount => Interlocked.Read(ref _sentPacketsCount); + public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken) { if (adapter == null) throw new ArgumentNullException(nameof(adapter)); @@ -52,7 +56,7 @@ namespace MQTTnet.Server Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken); } - public void Enqueue(MqttBasePacket packet) + public void Enqueue(MqttPublishPacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); @@ -70,11 +74,11 @@ namespace MQTTnet.Server _queue.Dequeue(); } } - + _queue.Enqueue(packet); } - _queueAutoResetEvent.Set(); + _queueLock.Set(); _logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId); } @@ -114,6 +118,11 @@ namespace MQTTnet.Server MqttBasePacket packet = null; try { + if (cancellationToken.IsCancellationRequested) + { + return; + } + lock (_queue) { if (_queue.Count > 0) @@ -124,18 +133,15 @@ namespace MQTTnet.Server if (packet == null) { - await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false); - return; - } - - if (cancellationToken.IsCancellationRequested) - { + await _queueLock.WaitOneAsync(cancellationToken).ConfigureAwait(false); return; } await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); + + Interlocked.Increment(ref _sentPacketsCount); } catch (Exception exception) { diff --git a/Source/MQTTnet/Server/MqttClientSessionStatus.cs b/Source/MQTTnet/Server/MqttClientSessionStatus.cs index 2673d7e..007a19b 100644 --- a/Source/MQTTnet/Server/MqttClientSessionStatus.cs +++ b/Source/MQTTnet/Server/MqttClientSessionStatus.cs @@ -21,7 +21,9 @@ namespace MQTTnet.Server public MqttProtocolVersion? ProtocolVersion { get; set; } public TimeSpan LastPacketReceived { get; set; } public TimeSpan LastNonKeepAlivePacketReceived { get; set; } - public int PendingApplicationMessagesCount { get; set; } + public long PendingApplicationMessagesCount { get; set; } + public long ReceivedApplicationMessagesCount { get; set; } + public long SentApplicationMessagesCount { get; set; } public Task DisconnectAsync() { diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 9707e58..35379ce 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -49,27 +50,31 @@ namespace MQTTnet.Server public async Task StopAsync() { - using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) + List sessions; + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { - foreach (var session in _sessions) - { - await session.Value.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); - } + sessions = _sessions.Values.ToList(); + } - _sessions.Clear(); + foreach (var session in sessions) + { + await session.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); } } public Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter) { - return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken); + return HandleConnectionAsync(clientAdapter, _cancellationToken); + + // TODO: Check if Task.Run is required. + //return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken); } public async Task> GetClientStatusAsync() { var result = new List(); - using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { foreach (var session in _sessions.Values) { @@ -90,42 +95,47 @@ namespace MQTTnet.Server _messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); } - public Task SubscribeAsync(string clientId, IList topicFilters) + public async Task SubscribeAsync(string clientId, IEnumerable topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - lock (_sessions) + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { if (!_sessions.TryGetValue(clientId, out var session)) { throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } - return session.SubscribeAsync(topicFilters); + await session.SubscribeAsync(topicFilters).ConfigureAwait(false); } } - public Task UnsubscribeAsync(string clientId, IList topicFilters) + public async Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - lock (_sessions) + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { if (!_sessions.TryGetValue(clientId, out var session)) { throw new InvalidOperationException($"Client session '{clientId}' is unknown."); } - return session.UnsubscribeAsync(topicFilters); + await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false); } } public async Task DeleteSessionAsync(string clientId) { - using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { + if (_sessions.TryGetValue(clientId, out var session)) + { + session.Dispose(); + } + _sessions.Remove(clientId); } @@ -187,7 +197,7 @@ namespace MQTTnet.Server await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); } - using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { foreach (var clientSession in _sessions.Values) { @@ -207,49 +217,37 @@ namespace MQTTnet.Server } } - private async Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) + private async Task HandleConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { var clientId = string.Empty; - + try { - var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); + var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); if (!(firstPacket is MqttConnectPacket connectPacket)) { - _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", clientAdapter.Endpoint); + _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint); return; } clientId = connectPacket.ClientId; - var connectReturnCode = await ValidateConnectionAsync(connectPacket, clientAdapter).ConfigureAwait(false); + var connectReturnCode = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await clientAdapter.SendPacketAsync( + await channelAdapter.SendPacketAsync( new MqttConnAckPacket { - ReturnCode = connectReturnCode + ReturnCode = connectReturnCode, + ReasonCode = MqttConnectReasonCode.NotAuthorized }, cancellationToken).ConfigureAwait(false); return; } - var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); - - await clientAdapter.SendPacketAsync( - new MqttConnAckPacket - { - ReturnCode = connectReturnCode, - ReasonCode = MqttConnectReasonCode.Success, - IsSessionPresent = result.IsExistingSession - }, - cancellationToken).ConfigureAwait(false); - - _logger.Info("Client '{0}': Connected.", clientId); - _eventDispatcher.OnClientConnected(clientId); - - await result.Session.RunAsync(connectPacket, clientAdapter).ConfigureAwait(false); + var session = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); + await session.RunAsync(connectPacket, channelAdapter).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -260,12 +258,15 @@ namespace MQTTnet.Server } finally { - await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); - clientAdapter.Dispose(); + await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); if (!_options.EnablePersistentSessions) { - await DeleteSessionAsync(clientId).ConfigureAwait(false); + // TODO: Check if the session will be used later. + // Consider reference counter or "Recycle" property + // Or add timer (will be required for MQTTv5 (session life time) "IsActiveProperty". + //öö + //await DeleteSessionAsync(clientId).ConfigureAwait(false); } } } @@ -288,18 +289,17 @@ namespace MQTTnet.Server return context.ReturnCode; } - private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) + private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) { - using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) + using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) { - await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); - if (connectPacket.CleanSession) { - _sessions.Remove(connectPacket.ClientId); + await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); + clientSession.Dispose(); clientSession = null; @@ -307,22 +307,21 @@ namespace MQTTnet.Server } else { + await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); + _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId); } } - var isExistingSession = true; if (clientSession == null) { - isExistingSession = false; - clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _eventDispatcher, _logger); _sessions[connectPacket.ClientId] = clientSession; _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } - return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; + return clientSession; } } @@ -338,5 +337,21 @@ namespace MQTTnet.Server await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); return interceptorContext; } + + private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter) + { + try + { + await channelAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); + } + catch (Exception exception) + { + _logger.Error(exception, "Error while disconnecting client channel."); + } + finally + { + channelAdapter.Dispose(); + } + } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index 1b5fb3b..ffba8fb 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -56,7 +56,7 @@ namespace MQTTnet.Server return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult(); } - public Task SubscribeAsync(string clientId, IList topicFilters) + public Task SubscribeAsync(string clientId, IEnumerable topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -64,7 +64,7 @@ namespace MQTTnet.Server return _clientSessionsManager.SubscribeAsync(clientId, topicFilters); } - public Task UnsubscribeAsync(string clientId, IList topicFilters) + public Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) { if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -118,7 +118,7 @@ namespace MQTTnet.Server _cancellationTokenSource.Cancel(false); - _clientSessionsManager.StopAsync().ConfigureAwait(false); + await _clientSessionsManager.StopAsync().ConfigureAwait(false); foreach (var adapter in _adapters) { diff --git a/Source/MQTTnet/Server/MqttServerExtensions.cs b/Source/MQTTnet/Server/MqttServerExtensions.cs new file mode 100644 index 0000000..e320f13 --- /dev/null +++ b/Source/MQTTnet/Server/MqttServerExtensions.cs @@ -0,0 +1,45 @@ +using System; +using System.Threading.Tasks; +using MQTTnet.Protocol; + +namespace MQTTnet.Server +{ + public static class MqttServerExtensions + { + public static Task SubscribeAsync(this IMqttServer server, string clientId, params TopicFilter[] topicFilters) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return server.SubscribeAsync(clientId, topicFilters); + } + + public static Task SubscribeAsync(this IMqttServer server, string clientId, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return server.SubscribeAsync(clientId, new TopicFilterBuilder().WithTopic(topic).WithQualityOfServiceLevel(qualityOfServiceLevel).Build()); + } + + public static Task SubscribeAsync(this IMqttServer server, string clientId, string topic) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topic == null) throw new ArgumentNullException(nameof(topic)); + + return server.SubscribeAsync(clientId, new TopicFilterBuilder().WithTopic(topic).Build()); + } + + public static Task UnsubscribeAsync(this IMqttServer server, string clientId, params string[] topicFilters) + { + if (server == null) throw new ArgumentNullException(nameof(server)); + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return server.UnsubscribeAsync(clientId, topicFilters); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index b3f9f21..915bc80 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -17,7 +17,7 @@ using MQTTnet.Server; namespace MQTTnet.Tests { [TestClass] - public class MqttServerTests + public partial class MqttServerTests { [TestMethod] public void MqttServer_PublishSimple_AtMostOnce() @@ -263,6 +263,91 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task MqttServer_No_Messages_If_No_Subscription() + { + var server = new MqttFactory().CreateMqttServer(); + try + { + await server.StartAsync(new MqttServerOptions()); + + var client = new MqttFactory().CreateMqttClient(); + var receivedMessages = new List(); + + var options = new MqttClientOptionsBuilder() + .WithTcpServer("localhost").Build(); + + client.Connected += async (s, e) => + { + await client.PublishAsync("Connected"); + }; + + client.ApplicationMessageReceived += (s, e) => + { + lock (receivedMessages) + { + receivedMessages.Add(e.ApplicationMessage); + } + }; + + await client.ConnectAsync(options); + + await Task.Delay(500); + + await client.PublishAsync("Hello"); + + await Task.Delay(500); + + Assert.AreEqual(0, receivedMessages.Count); + } + finally + { + await server.StopAsync(); + } + } + + [TestMethod] + public async Task MqttServer_Set_Subscription_At_Server() + { + var server = new MqttFactory().CreateMqttServer(); + try + { + await server.StartAsync(new MqttServerOptions()); + server.ClientConnected += async (s, e) => + { + await server.SubscribeAsync(e.ClientId, "topic1"); + }; + + var client = new MqttFactory().CreateMqttClient(); + var receivedMessages = new List(); + + var options = new MqttClientOptionsBuilder() + .WithTcpServer("localhost").Build(); + + client.ApplicationMessageReceived += (s, e) => + { + lock (receivedMessages) + { + receivedMessages.Add(e.ApplicationMessage); + } + }; + + await client.ConnectAsync(options); + + await Task.Delay(500); + + await client.PublishAsync("Hello"); + + await Task.Delay(500); + + Assert.AreEqual(0, receivedMessages.Count); + } + finally + { + await server.StopAsync(); + } + } + private static async Task Publish(IMqttClient c1, MqttApplicationMessage message) { for (int i = 0; i < 1000; i++) @@ -302,40 +387,29 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttServer_Handle_Clean_Disconnect() { - var s = new MqttFactory().CreateMqttServer(); - try + using (var testSetup = new TestSetup()) { + var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); + var clientConnectedCalled = 0; var clientDisconnectedCalled = 0; - s.ClientConnected += (_, __) => clientConnectedCalled++; - s.ClientDisconnected += (_, __) => clientDisconnectedCalled++; - - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost") - .Build(); + server.ClientConnected += (_, __) => Interlocked.Increment(ref clientConnectedCalled); + server.ClientDisconnected += (_, __) => Interlocked.Increment(ref clientDisconnectedCalled); + + var c1 = await testSetup.ConnectClient(new MqttClientOptionsBuilder()); - await s.StartAsync(new MqttServerOptions()); + Assert.AreEqual(1, clientConnectedCalled); + Assert.AreEqual(0, clientDisconnectedCalled); - var c1 = new MqttFactory().CreateMqttClient(); - - await c1.ConnectAsync(clientOptions); - - await Task.Delay(100); + await Task.Delay(500); await c1.DisconnectAsync(); - await Task.Delay(100); - - await s.StopAsync(); - - await Task.Delay(100); + await Task.Delay(500); - Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled); - } - finally - { - await s.StopAsync(); + Assert.AreEqual(1, clientConnectedCalled); + Assert.AreEqual(1, clientDisconnectedCalled); } } @@ -385,7 +459,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_LotsOfRetainedMessages() + public async Task MqttServer_Lots_Of_Retained_Messages() { const int ClientCount = 100; @@ -745,58 +819,64 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_SameClientIdConnectDisconnectEventOrder() + public async Task MqttServer_Same_Client_Id_Connect_Disconnect_Event_Order() { - var s = new MqttFactory().CreateMqttServer(); + using (var testSetup = new TestSetup()) + { + var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); - var events = new List(); + var events = new List(); - s.ClientConnected += (_, __) => - { - lock (events) + server.ClientConnected += (_, __) => { - events.Add("c"); - } - }; + lock (events) + { + events.Add("c"); + } + }; - s.ClientDisconnected += (_, __) => - { - lock (events) + server.ClientDisconnected += (_, __) => { - events.Add("d"); - } - }; - - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost") - .WithClientId("same_id") - .Build(); + lock (events) + { + events.Add("d"); + } + }; - await s.StartAsync(new MqttServerOptions()); + var clientOptions = new MqttClientOptionsBuilder() + .WithClientId("same_id"); - var c1 = new MqttFactory().CreateMqttClient(); - var c2 = new MqttFactory().CreateMqttClient(); + // c + var c1 = await testSetup.ConnectClient(clientOptions); + + await Task.Delay(500); - await c1.ConnectAsync(clientOptions); + var flow = string.Join(string.Empty, events); + Assert.AreEqual("c", flow); - await Task.Delay(250); + // dc + var c2 = await testSetup.ConnectClient(clientOptions); - await c2.ConnectAsync(clientOptions); + await Task.Delay(500); - await Task.Delay(250); + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdc", flow); - await c1.DisconnectAsync(); + // nothing + await c1.DisconnectAsync(); - await Task.Delay(250); + await Task.Delay(500); - await c2.DisconnectAsync(); + // d + await c2.DisconnectAsync(); - await Task.Delay(250); + await Task.Delay(500); - await s.StopAsync(); + await server.StopAsync(); - var flow = string.Join(string.Empty, events); - Assert.AreEqual("cdcd", flow); + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdcd", flow); + } } [TestMethod] diff --git a/Tests/MQTTnet.Core.Tests/TestSetup.cs b/Tests/MQTTnet.Core.Tests/TestSetup.cs new file mode 100644 index 0000000..2c07170 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/TestSetup.cs @@ -0,0 +1,92 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Client; +using MQTTnet.Client.Options; +using MQTTnet.Diagnostics; +using MQTTnet.Server; + +namespace MQTTnet.Tests +{ + public class TestSetup : IDisposable + { + private readonly MqttFactory _mqttFactory = new MqttFactory(); + private readonly List _clients = new List(); + private readonly IMqttNetLogger _serverLogger = new MqttNetLogger("server"); + private readonly IMqttNetLogger _clientLogger = new MqttNetLogger("client"); + + private IMqttServer _server; + + private long _serverErrorsCount; + private long _clientErrorsCount; + + public TestSetup() + { + _serverLogger.LogMessagePublished += (s, e) => + { + if (e.TraceMessage.Level == MqttNetLogLevel.Error) + { + Interlocked.Increment(ref _serverErrorsCount); + } + }; + + _clientLogger.LogMessagePublished += (s, e) => + { + if (e.TraceMessage.Level == MqttNetLogLevel.Error) + { + Interlocked.Increment(ref _clientErrorsCount); + } + }; + } + + public async Task StartServerAsync(MqttServerOptionsBuilder options) + { + if (_server != null) + { + throw new InvalidOperationException("Server already started."); + } + + _server = _mqttFactory.CreateMqttServer(_serverLogger); + await _server.StartAsync(options.WithDefaultEndpointPort(1888).Build()); + + return _server; + } + + public async Task ConnectClient(MqttClientOptionsBuilder options) + { + var client = _mqttFactory.CreateMqttClient(_clientLogger); + _clients.Add(client); + + await client.ConnectAsync(options.WithTcpServer("localhost", 1888).Build()); + + return client; + } + + public void ThrowIfLogErrors() + { + if (_serverErrorsCount > 0) + { + throw new Exception($"Server had {_serverErrorsCount} errors."); + } + + if (_clientErrorsCount > 0) + { + throw new Exception($"Client(s) had {_clientErrorsCount} errors."); + } + } + + public void Dispose() + { + ThrowIfLogErrors(); + + foreach (var mqttClient in _clients) + { + mqttClient.DisconnectAsync().GetAwaiter().GetResult(); + mqttClient.Dispose(); + } + + _server.StopAsync().GetAwaiter().GetResult(); + } + } +}