Quellcode durchsuchen

Improve session handling.

release/3.x.x
Christian Kratky vor 5 Jahren
Ursprung
Commit
96a67579e2
13 geänderte Dateien mit 477 neuen und 218 gelöschten Zeilen
  1. +8
    -6
      Source/MQTTnet/Client/MqttClient.cs
  2. +2
    -3
      Source/MQTTnet/Client/MqttClientExtensions.cs
  3. +1
    -1
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  4. +5
    -1
      Source/MQTTnet/Server/IMqttClientSessionStatus.cs
  5. +2
    -2
      Source/MQTTnet/Server/IMqttServer.cs
  6. +92
    -78
      Source/MQTTnet/Server/MqttClientSession.cs
  7. +19
    -13
      Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs
  8. +3
    -1
      Source/MQTTnet/Server/MqttClientSessionStatus.cs
  9. +65
    -50
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  10. +3
    -3
      Source/MQTTnet/Server/MqttServer.cs
  11. +45
    -0
      Source/MQTTnet/Server/MqttServerExtensions.cs
  12. +140
    -60
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  13. +92
    -0
      Tests/MQTTnet.Core.Tests/TestSetup.cs

+ 8
- 6
Source/MQTTnet/Client/MqttClient.cs Datei anzeigen

@@ -66,24 +66,26 @@ namespace MQTTnet.Client
_packetIdentifierProvider.Reset();
_packetDispatcher.Reset();

_cancellationTokenSource = new CancellationTokenSource();
var cancellationTokenSource = new CancellationTokenSource();
_cancellationTokenSource = cancellationTokenSource;

_disconnectGate = 0;
_adapter = _adapterFactory.CreateClientAdapter(options, _logger);

_logger.Verbose($"Trying to connect with server ({Options.ChannelOptions}).");
await _adapter.ConnectAsync(Options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false);
await _adapter.ConnectAsync(Options.CommunicationTimeout, cancellationTokenSource.Token).ConfigureAwait(false);
_logger.Verbose("Connection with server established.");

StartReceivingPackets(_cancellationTokenSource.Token);
StartReceivingPackets(cancellationTokenSource.Token);

var connectResult = await AuthenticateAsync(options.WillMessage, _cancellationTokenSource.Token).ConfigureAwait(false);
var connectResult = await AuthenticateAsync(options.WillMessage, cancellationTokenSource.Token).ConfigureAwait(false);
_logger.Verbose("MQTT connection with server established.");

_sendTracker.Restart();

if (Options.KeepAlivePeriod != TimeSpan.Zero)
{
StartSendingKeepAliveMessages(_cancellationTokenSource.Token);
StartSendingKeepAliveMessages(cancellationTokenSource.Token);
}

IsConnected = true;
@@ -112,7 +114,7 @@ namespace MQTTnet.Client
{
_cleanDisconnectInitiated = true;

if (IsConnected && !_cancellationTokenSource.IsCancellationRequested)
if (IsConnected && _cancellationTokenSource?.IsCancellationRequested == false)
{
var disconnectPacket = CreateDisconnectPacket(options);
await SendAsync(disconnectPacket, _cancellationTokenSource.Token).ConfigureAwait(false);


+ 2
- 3
Source/MQTTnet/Client/MqttClientExtensions.cs Datei anzeigen

@@ -1,5 +1,4 @@
using System;
using System.Linq;
using System.Threading.Tasks;
using MQTTnet.Client.Subscribing;
using MQTTnet.Client.Unsubscribing;
@@ -21,7 +20,7 @@ namespace MQTTnet.Client
if (client == null) throw new ArgumentNullException(nameof(client));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return client.SubscribeAsync(topicFilters.ToList());
return client.SubscribeAsync(topicFilters);
}

public static Task<MqttClientSubscribeResult> SubscribeAsync(this IMqttClient client, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel)
@@ -45,7 +44,7 @@ namespace MQTTnet.Client
if (client == null) throw new ArgumentNullException(nameof(client));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return client.UnsubscribeAsync(topicFilters.ToList());
return client.UnsubscribeAsync(topicFilters);
}
}
}

+ 1
- 1
Source/MQTTnet/Implementations/MqttTcpChannel.cs Datei anzeigen

@@ -93,7 +93,7 @@ namespace MQTTnet.Implementations
using (cancellationToken.Register(() => _socket.Dispose()))
{
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken);
//await _stream.FlushAsync(cancellationToken);
}
}



+ 5
- 1
Source/MQTTnet/Server/IMqttClientSessionStatus.cs Datei anzeigen

@@ -18,7 +18,11 @@ namespace MQTTnet.Server

TimeSpan LastNonKeepAlivePacketReceived { get; }

int PendingApplicationMessagesCount { get; }
long PendingApplicationMessagesCount { get; }

long ReceivedApplicationMessagesCount { get; }

long SentApplicationMessagesCount { get; }

Task DisconnectAsync();



+ 2
- 2
Source/MQTTnet/Server/IMqttServer.cs Datei anzeigen

@@ -21,8 +21,8 @@ namespace MQTTnet.Server
IList<MqttApplicationMessage> GetRetainedMessages();
Task ClearRetainedMessagesAsync();

Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters);
Task UnsubscribeAsync(string clientId, IList<string> topicFilters);
Task SubscribeAsync(string clientId, IEnumerable<TopicFilter> topicFilters);
Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters);

Task StartAsync(IMqttServerOptions options);
Task StopAsync();


+ 92
- 78
Source/MQTTnet/Server/MqttClientSession.cs Datei anzeigen

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
@@ -18,7 +19,7 @@ namespace MQTTnet.Server
private readonly MqttRetainedMessagesManager _retainedMessagesManager;
private readonly MqttServerEventDispatcher _eventDispatcher;
private readonly MqttClientKeepAliveMonitor _keepAliveMonitor;
private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue;
private readonly MqttClientSessionPendingMessagesQueue _pendingMessagesQueue;
private readonly MqttClientSubscriptionsManager _subscriptionsManager;
private readonly MqttClientSessionsManager _sessionsManager;

@@ -31,6 +32,9 @@ namespace MQTTnet.Server
private Task _workerTask;
private IMqttChannelAdapter _channelAdapter;

private long _receivedMessagesCount;
private bool _isCleanSession = true;

public MqttClientSession(
string clientId,
IMqttServerOptions options,
@@ -52,7 +56,7 @@ namespace MQTTnet.Server

_keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger);
_subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, eventDispatcher);
_pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger);
_pendingMessagesQueue = new MqttClientSessionPendingMessagesQueue(_options, this, _logger);
}

public string ClientId { get; }
@@ -63,28 +67,38 @@ namespace MQTTnet.Server
status.IsConnected = _cancellationTokenSource != null;
status.Endpoint = _channelAdapter?.Endpoint;
status.ProtocolVersion = _channelAdapter?.PacketFormatterAdapter?.ProtocolVersion;
status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count;
status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count;
status.ReceivedApplicationMessagesCount = _pendingMessagesQueue.SentMessagesCount;
status.SentApplicationMessagesCount = Interlocked.Read(ref _receivedMessagesCount);
status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived;
status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived;
}

public Task StopAsync(MqttClientDisconnectType type)
public async Task StopAsync(MqttClientDisconnectType type)
{
return StopAsync(type, false);
StopInternal(type);

var task = _workerTask;
if (task != null && !task.IsCompleted)
{
await task.ConfigureAwait(false);
}
}

public async Task SubscribeAsync(IList<TopicFilter> topicFilters)
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

var topicFiltersCollection = topicFilters.ToList();

var packet = new MqttSubscribePacket();
packet.TopicFilters.AddRange(topicFilters);
packet.TopicFilters.AddRange(topicFiltersCollection);

await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false);
await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false);
await EnqueueSubscribedRetainedMessagesAsync(topicFiltersCollection).ConfigureAwait(false);
}

public Task UnsubscribeAsync(IList<string> topicFilters)
public Task UnsubscribeAsync(IEnumerable<string> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

@@ -98,12 +112,12 @@ namespace MQTTnet.Server

public void ClearPendingApplicationMessages()
{
_pendingPacketsQueue.Clear();
_pendingMessagesQueue.Clear();
}

public void Dispose()
{
_pendingPacketsQueue?.Dispose();
_pendingMessagesQueue?.Dispose();

_cancellationTokenSource?.Cancel();
_cancellationTokenSource?.Dispose();
@@ -161,7 +175,7 @@ namespace MQTTnet.Server
publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel;
}

_pendingPacketsQueue.Enqueue(publishPacket);
_pendingMessagesQueue.Enqueue(publishPacket);
}

private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
@@ -170,26 +184,41 @@ namespace MQTTnet.Server

try
{
channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted;
channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted;
_logger.Info("Client '{0}': Connected.", ClientId);
_eventDispatcher.OnClientConnected(ClientId);

_channelAdapter = channelAdapter;

_channelAdapter.ReadingPacketStarted += OnAdapterReadingPacketStarted;
_channelAdapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted;

_cancellationTokenSource = new CancellationTokenSource();
var cancellationTokenSource = new CancellationTokenSource();
_cancellationTokenSource = cancellationTokenSource;

_wasCleanDisconnect = false;
_willMessage = connectPacket.WillMessage;

_pendingPacketsQueue.Start(channelAdapter, _cancellationTokenSource.Token);
_keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token);
_pendingMessagesQueue.Start(channelAdapter, cancellationTokenSource.Token);
_keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, cancellationTokenSource.Token);

_channelAdapter = channelAdapter;
await channelAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ReturnCode = MqttConnectReturnCode.ConnectionAccepted,
ReasonCode = MqttConnectReasonCode.Success,
IsSessionPresent = _isCleanSession
},
cancellationTokenSource.Token).ConfigureAwait(false);

while (!_cancellationTokenSource.IsCancellationRequested)
_isCleanSession = false;

while (!cancellationTokenSource.IsCancellationRequested)
{
var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false);
var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, cancellationTokenSource.Token).ConfigureAwait(false);
if (packet != null)
{
_keepAliveMonitor.PacketReceived(packet);
await ProcessReceivedPacketAsync(channelAdapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false);
await ProcessReceivedPacketAsync(channelAdapter, packet, cancellationTokenSource.Token).ConfigureAwait(false);
}
}
}
@@ -203,6 +232,9 @@ namespace MQTTnet.Server
if (exception is MqttCommunicationClosedGracefullyException)
{
_logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId);

StopInternal(MqttClientDisconnectType.Clean);
return;
}
else
{
@@ -214,69 +246,50 @@ namespace MQTTnet.Server
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId);
}

await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
StopInternal(MqttClientDisconnectType.NotClean);
}
finally
{
channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;

_cancellationTokenSource?.Cancel(false);
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;

_workerTask = null;
}
}

private async Task StopAsync(MqttClientDisconnectType type, bool isInsideSession)
{
try
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

_wasCleanDisconnect = type == MqttClientDisconnectType.Clean;

_cancellationTokenSource?.Cancel(false);

if (_willMessage != null && !_wasCleanDisconnect)
{
_sessionsManager.EnqueueApplicationMessage(this, _willMessage);
}

_willMessage = null;

if (!isInsideSession)
{
if (_workerTask != null)
{
await _workerTask.ConfigureAwait(false);
}
}
await Task.FromResult(0);
}
finally
{
_channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
_channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
_channelAdapter = null;
_logger.Info("Client '{0}': Session stopped.", ClientId);
_eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect);

_workerTask = null;
}
}
private void StopInternal(MqttClientDisconnectType type)
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

_wasCleanDisconnect = type == MqttClientDisconnectType.Clean;
_cancellationTokenSource?.Cancel(false);
}

private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
private Task ProcessReceivedPacketAsync(IMqttChannelAdapter channelAdapter, MqttBasePacket packet, CancellationToken cancellationToken)
{
if (packet is MqttPublishPacket publishPacket)
{
return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken);
return HandleIncomingPublishPacketAsync(channelAdapter, publishPacket, cancellationToken);
}

if (packet is MqttPingReqPacket)
{
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
return channelAdapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
}

if (packet is MqttPubRelPacket pubRelPacket)
@@ -287,7 +300,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubCompReasonCode.Success
};

return adapter.SendPacketAsync(responsePacket, cancellationToken);
return channelAdapter.SendPacketAsync(responsePacket, cancellationToken);
}

if (packet is MqttPubRecPacket pubRecPacket)
@@ -298,7 +311,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubRelReasonCode.Success
};

return adapter.SendPacketAsync(responsePacket, cancellationToken);
return channelAdapter.SendPacketAsync(responsePacket, cancellationToken);
}

if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
@@ -308,27 +321,24 @@ namespace MQTTnet.Server

if (packet is MqttSubscribePacket subscribePacket)
{
return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken);
return HandleIncomingSubscribePacketAsync(channelAdapter, subscribePacket, cancellationToken);
}

if (packet is MqttUnsubscribePacket unsubscribePacket)
{
return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken);
return HandleIncomingUnsubscribePacketAsync(channelAdapter, unsubscribePacket, cancellationToken);
}

if (packet is MqttDisconnectPacket)
{
return StopAsync(MqttClientDisconnectType.Clean, true);
}

if (packet is MqttConnectPacket)
{
return StopAsync(MqttClientDisconnectType.NotClean, true);
StopInternal(MqttClientDisconnectType.Clean);
return Task.FromResult(0);
}

_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet);
_logger.Warning(null, "Client '{0}': Received invalid packet ({1}). Closing connection.", ClientId, packet);

return StopAsync(MqttClientDisconnectType.NotClean, true);
StopInternal(MqttClientDisconnectType.NotClean);
return Task.FromResult(0);
}

private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
@@ -347,7 +357,8 @@ namespace MQTTnet.Server

if (subscribeResult.CloseConnection)
{
await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
StopInternal(MqttClientDisconnectType.NotClean);
return;
}

await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false);
@@ -361,12 +372,13 @@ namespace MQTTnet.Server

private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{
Interlocked.Increment(ref _receivedMessagesCount);

switch (publishPacket.QualityOfServiceLevel)
{
case MqttQualityOfServiceLevel.AtMostOnce:
{
HandleIncomingPublishPacketWithQoS0(publishPacket);
return Task.FromResult(0);
return HandleIncomingPublishPacketWithQoS0Async(publishPacket);
}
case MqttQualityOfServiceLevel.AtLeastOnce:
{
@@ -383,11 +395,13 @@ namespace MQTTnet.Server
}
}

private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket)
private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket)
{
_sessionsManager.EnqueueApplicationMessage(
this,
_channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket));

return Task.FromResult(0);
}

private Task HandleIncomingPublishPacketWithQoS1Async(


Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs → Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs Datei anzeigen

@@ -11,22 +11,24 @@ using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public class MqttClientPendingPacketsQueue : IDisposable
public class MqttClientSessionPendingMessagesQueue : IDisposable
{
private readonly Queue<MqttBasePacket> _queue = new Queue<MqttBasePacket>();
private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent();
private readonly AsyncAutoResetEvent _queueLock = new AsyncAutoResetEvent();

private readonly IMqttServerOptions _options;
private readonly MqttClientSession _clientSession;
private readonly IMqttNetChildLogger _logger;

public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
private long _sentPacketsCount;

public MqttClientSessionPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
{
if (logger == null) throw new ArgumentNullException(nameof(logger));
_options = options ?? throw new ArgumentNullException(nameof(options));
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession));

_logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue));
_logger = logger.CreateChildLogger(nameof(MqttClientSessionPendingMessagesQueue));
}

public int Count
@@ -40,6 +42,8 @@ namespace MQTTnet.Server
}
}

public long SentMessagesCount => Interlocked.Read(ref _sentPacketsCount);

public void Start(IMqttChannelAdapter adapter, CancellationToken cancellationToken)
{
if (adapter == null) throw new ArgumentNullException(nameof(adapter));
@@ -52,7 +56,7 @@ namespace MQTTnet.Server
Task.Run(() => SendQueuedPacketsAsync(adapter, cancellationToken), cancellationToken);
}

public void Enqueue(MqttBasePacket packet)
public void Enqueue(MqttPublishPacket packet)
{
if (packet == null) throw new ArgumentNullException(nameof(packet));

@@ -70,11 +74,11 @@ namespace MQTTnet.Server
_queue.Dequeue();
}
}
_queue.Enqueue(packet);
}

_queueAutoResetEvent.Set();
_queueLock.Set();

_logger.Verbose("Enqueued packet (ClientId: {0}).", _clientSession.ClientId);
}
@@ -114,6 +118,11 @@ namespace MQTTnet.Server
MqttBasePacket packet = null;
try
{
if (cancellationToken.IsCancellationRequested)
{
return;
}

lock (_queue)
{
if (_queue.Count > 0)
@@ -124,18 +133,15 @@ namespace MQTTnet.Server

if (packet == null)
{
await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false);
return;
}

if (cancellationToken.IsCancellationRequested)
{
await _queueLock.WaitOneAsync(cancellationToken).ConfigureAwait(false);
return;
}

await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false);

_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId);

Interlocked.Increment(ref _sentPacketsCount);
}
catch (Exception exception)
{

+ 3
- 1
Source/MQTTnet/Server/MqttClientSessionStatus.cs Datei anzeigen

@@ -21,7 +21,9 @@ namespace MQTTnet.Server
public MqttProtocolVersion? ProtocolVersion { get; set; }
public TimeSpan LastPacketReceived { get; set; }
public TimeSpan LastNonKeepAlivePacketReceived { get; set; }
public int PendingApplicationMessagesCount { get; set; }
public long PendingApplicationMessagesCount { get; set; }
public long ReceivedApplicationMessagesCount { get; set; }
public long SentApplicationMessagesCount { get; set; }

public Task DisconnectAsync()
{


+ 65
- 50
Source/MQTTnet/Server/MqttClientSessionsManager.cs Datei anzeigen

@@ -1,6 +1,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
@@ -49,27 +50,31 @@ namespace MQTTnet.Server

public async Task StopAsync()
{
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
List<MqttClientSession> sessions;
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
foreach (var session in _sessions)
{
await session.Value.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
}
sessions = _sessions.Values.ToList();
}

_sessions.Clear();
foreach (var session in sessions)
{
await session.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);
}
}

public Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter)
{
return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken);
return HandleConnectionAsync(clientAdapter, _cancellationToken);

// TODO: Check if Task.Run is required.
//return Task.Run(() => HandleConnectionAsync(clientAdapter, _cancellationToken), _cancellationToken);
}

public async Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
{
var result = new List<IMqttClientSessionStatus>();

using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
foreach (var session in _sessions.Values)
{
@@ -90,42 +95,47 @@ namespace MQTTnet.Server
_messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken);
}

public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
public async Task SubscribeAsync(string clientId, IEnumerable<TopicFilter> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

lock (_sessions)
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
}

return session.SubscribeAsync(topicFilters);
await session.SubscribeAsync(topicFilters).ConfigureAwait(false);
}
}

public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
public async Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

lock (_sessions)
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
}

return session.UnsubscribeAsync(topicFilters);
await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false);
}
}

public async Task DeleteSessionAsync(string clientId)
{
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
if (_sessions.TryGetValue(clientId, out var session))
{
session.Dispose();
}

_sessions.Remove(clientId);
}

@@ -187,7 +197,7 @@ namespace MQTTnet.Server
await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
}

using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
foreach (var clientSession in _sessions.Values)
{
@@ -207,49 +217,37 @@ namespace MQTTnet.Server
}
}

private async Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
private async Task HandleConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{
var clientId = string.Empty;
try
{
var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
if (!(firstPacket is MqttConnectPacket connectPacket))
{
_logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", clientAdapter.Endpoint);
_logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint);
return;
}

clientId = connectPacket.ClientId;

var connectReturnCode = await ValidateConnectionAsync(connectPacket, clientAdapter).ConfigureAwait(false);
var connectReturnCode = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{
await clientAdapter.SendPacketAsync(
await channelAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ReturnCode = connectReturnCode
ReturnCode = connectReturnCode,
ReasonCode = MqttConnectReasonCode.NotAuthorized
},
cancellationToken).ConfigureAwait(false);

return;
}

var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);

await clientAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ReturnCode = connectReturnCode,
ReasonCode = MqttConnectReasonCode.Success,
IsSessionPresent = result.IsExistingSession
},
cancellationToken).ConfigureAwait(false);

_logger.Info("Client '{0}': Connected.", clientId);
_eventDispatcher.OnClientConnected(clientId);

await result.Session.RunAsync(connectPacket, clientAdapter).ConfigureAwait(false);
var session = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);
await session.RunAsync(connectPacket, channelAdapter).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
@@ -260,12 +258,15 @@ namespace MQTTnet.Server
}
finally
{
await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
clientAdapter.Dispose();
await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false);

if (!_options.EnablePersistentSessions)
{
await DeleteSessionAsync(clientId).ConfigureAwait(false);
// TODO: Check if the session will be used later.
// Consider reference counter or "Recycle" property
// Or add timer (will be required for MQTTv5 (session life time) "IsActiveProperty".
//öö
//await DeleteSessionAsync(clientId).ConfigureAwait(false);
}
}
}
@@ -288,18 +289,17 @@ namespace MQTTnet.Server
return context.ReturnCode;
}

private async Task<PrepareClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
private async Task<MqttClientSession> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
{
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
using (await _sessionsLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
if (isSessionPresent)
{
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);

if (connectPacket.CleanSession)
{
_sessions.Remove(connectPacket.ClientId);
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);

clientSession.Dispose();
clientSession = null;

@@ -307,22 +307,21 @@ namespace MQTTnet.Server
}
else
{
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);

_logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId);
}
}

var isExistingSession = true;
if (clientSession == null)
{
isExistingSession = false;

clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _eventDispatcher, _logger);
_sessions[connectPacket.ClientId] = clientSession;

_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
}

return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession };
return clientSession;
}
}

@@ -338,5 +337,21 @@ namespace MQTTnet.Server
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
return interceptorContext;
}

private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter)
{
try
{
await channelAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while disconnecting client channel.");
}
finally
{
channelAdapter.Dispose();
}
}
}
}

+ 3
- 3
Source/MQTTnet/Server/MqttServer.cs Datei anzeigen

@@ -56,7 +56,7 @@ namespace MQTTnet.Server
return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult();
}

public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
public Task SubscribeAsync(string clientId, IEnumerable<TopicFilter> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
@@ -64,7 +64,7 @@ namespace MQTTnet.Server
return _clientSessionsManager.SubscribeAsync(clientId, topicFilters);
}

public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
public Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
@@ -118,7 +118,7 @@ namespace MQTTnet.Server

_cancellationTokenSource.Cancel(false);
_clientSessionsManager.StopAsync().ConfigureAwait(false);
await _clientSessionsManager.StopAsync().ConfigureAwait(false);

foreach (var adapter in _adapters)
{


+ 45
- 0
Source/MQTTnet/Server/MqttServerExtensions.cs Datei anzeigen

@@ -0,0 +1,45 @@
using System;
using System.Threading.Tasks;
using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public static class MqttServerExtensions
{
public static Task SubscribeAsync(this IMqttServer server, string clientId, params TopicFilter[] topicFilters)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return server.SubscribeAsync(clientId, topicFilters);
}

public static Task SubscribeAsync(this IMqttServer server, string clientId, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topic == null) throw new ArgumentNullException(nameof(topic));

return server.SubscribeAsync(clientId, new TopicFilterBuilder().WithTopic(topic).WithQualityOfServiceLevel(qualityOfServiceLevel).Build());
}

public static Task SubscribeAsync(this IMqttServer server, string clientId, string topic)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topic == null) throw new ArgumentNullException(nameof(topic));

return server.SubscribeAsync(clientId, new TopicFilterBuilder().WithTopic(topic).Build());
}

public static Task UnsubscribeAsync(this IMqttServer server, string clientId, params string[] topicFilters)
{
if (server == null) throw new ArgumentNullException(nameof(server));
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return server.UnsubscribeAsync(clientId, topicFilters);
}
}
}

+ 140
- 60
Tests/MQTTnet.Core.Tests/MqttServerTests.cs Datei anzeigen

@@ -17,7 +17,7 @@ using MQTTnet.Server;
namespace MQTTnet.Tests
{
[TestClass]
public class MqttServerTests
public partial class MqttServerTests
{
[TestMethod]
public void MqttServer_PublishSimple_AtMostOnce()
@@ -263,6 +263,91 @@ namespace MQTTnet.Tests
}
}

[TestMethod]
public async Task MqttServer_No_Messages_If_No_Subscription()
{
var server = new MqttFactory().CreateMqttServer();
try
{
await server.StartAsync(new MqttServerOptions());

var client = new MqttFactory().CreateMqttClient();
var receivedMessages = new List<MqttApplicationMessage>();

var options = new MqttClientOptionsBuilder()
.WithTcpServer("localhost").Build();

client.Connected += async (s, e) =>
{
await client.PublishAsync("Connected");
};

client.ApplicationMessageReceived += (s, e) =>
{
lock (receivedMessages)
{
receivedMessages.Add(e.ApplicationMessage);
}
};

await client.ConnectAsync(options);

await Task.Delay(500);

await client.PublishAsync("Hello");

await Task.Delay(500);
Assert.AreEqual(0, receivedMessages.Count);
}
finally
{
await server.StopAsync();
}
}

[TestMethod]
public async Task MqttServer_Set_Subscription_At_Server()
{
var server = new MqttFactory().CreateMqttServer();
try
{
await server.StartAsync(new MqttServerOptions());
server.ClientConnected += async (s, e) =>
{
await server.SubscribeAsync(e.ClientId, "topic1");
};

var client = new MqttFactory().CreateMqttClient();
var receivedMessages = new List<MqttApplicationMessage>();

var options = new MqttClientOptionsBuilder()
.WithTcpServer("localhost").Build();

client.ApplicationMessageReceived += (s, e) =>
{
lock (receivedMessages)
{
receivedMessages.Add(e.ApplicationMessage);
}
};

await client.ConnectAsync(options);

await Task.Delay(500);

await client.PublishAsync("Hello");

await Task.Delay(500);

Assert.AreEqual(0, receivedMessages.Count);
}
finally
{
await server.StopAsync();
}
}

private static async Task Publish(IMqttClient c1, MqttApplicationMessage message)
{
for (int i = 0; i < 1000; i++)
@@ -302,40 +387,29 @@ namespace MQTTnet.Tests
[TestMethod]
public async Task MqttServer_Handle_Clean_Disconnect()
{
var s = new MqttFactory().CreateMqttServer();
try
using (var testSetup = new TestSetup())
{
var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder());

var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0;

s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();
server.ClientConnected += (_, __) => Interlocked.Increment(ref clientConnectedCalled);
server.ClientDisconnected += (_, __) => Interlocked.Increment(ref clientDisconnectedCalled);
var c1 = await testSetup.ConnectClient(new MqttClientOptionsBuilder());

await s.StartAsync(new MqttServerOptions());
Assert.AreEqual(1, clientConnectedCalled);
Assert.AreEqual(0, clientDisconnectedCalled);

var c1 = new MqttFactory().CreateMqttClient();

await c1.ConnectAsync(clientOptions);

await Task.Delay(100);
await Task.Delay(500);

await c1.DisconnectAsync();

await Task.Delay(100);

await s.StopAsync();

await Task.Delay(100);
await Task.Delay(500);

Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
}
finally
{
await s.StopAsync();
Assert.AreEqual(1, clientConnectedCalled);
Assert.AreEqual(1, clientDisconnectedCalled);
}
}

@@ -385,7 +459,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_LotsOfRetainedMessages()
public async Task MqttServer_Lots_Of_Retained_Messages()
{
const int ClientCount = 100;

@@ -745,58 +819,64 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_SameClientIdConnectDisconnectEventOrder()
public async Task MqttServer_Same_Client_Id_Connect_Disconnect_Event_Order()
{
var s = new MqttFactory().CreateMqttServer();
using (var testSetup = new TestSetup())
{
var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder());

var events = new List<string>();
var events = new List<string>();

s.ClientConnected += (_, __) =>
{
lock (events)
server.ClientConnected += (_, __) =>
{
events.Add("c");
}
};
lock (events)
{
events.Add("c");
}
};

s.ClientDisconnected += (_, __) =>
{
lock (events)
server.ClientDisconnected += (_, __) =>
{
events.Add("d");
}
};

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.WithClientId("same_id")
.Build();
lock (events)
{
events.Add("d");
}
};

await s.StartAsync(new MqttServerOptions());
var clientOptions = new MqttClientOptionsBuilder()
.WithClientId("same_id");

var c1 = new MqttFactory().CreateMqttClient();
var c2 = new MqttFactory().CreateMqttClient();
// c
var c1 = await testSetup.ConnectClient(clientOptions);
await Task.Delay(500);

await c1.ConnectAsync(clientOptions);
var flow = string.Join(string.Empty, events);
Assert.AreEqual("c", flow);

await Task.Delay(250);
// dc
var c2 = await testSetup.ConnectClient(clientOptions);

await c2.ConnectAsync(clientOptions);
await Task.Delay(500);

await Task.Delay(250);
flow = string.Join(string.Empty, events);
Assert.AreEqual("cdc", flow);

await c1.DisconnectAsync();
// nothing
await c1.DisconnectAsync();

await Task.Delay(250);
await Task.Delay(500);

await c2.DisconnectAsync();
// d
await c2.DisconnectAsync();

await Task.Delay(250);
await Task.Delay(500);

await s.StopAsync();
await server.StopAsync();

var flow = string.Join(string.Empty, events);
Assert.AreEqual("cdcd", flow);
flow = string.Join(string.Empty, events);
Assert.AreEqual("cdcd", flow);
}
}

[TestMethod]


+ 92
- 0
Tests/MQTTnet.Core.Tests/TestSetup.cs Datei anzeigen

@@ -0,0 +1,92 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Client;
using MQTTnet.Client.Options;
using MQTTnet.Diagnostics;
using MQTTnet.Server;

namespace MQTTnet.Tests
{
public class TestSetup : IDisposable
{
private readonly MqttFactory _mqttFactory = new MqttFactory();
private readonly List<IMqttClient> _clients = new List<IMqttClient>();
private readonly IMqttNetLogger _serverLogger = new MqttNetLogger("server");
private readonly IMqttNetLogger _clientLogger = new MqttNetLogger("client");

private IMqttServer _server;

private long _serverErrorsCount;
private long _clientErrorsCount;

public TestSetup()
{
_serverLogger.LogMessagePublished += (s, e) =>
{
if (e.TraceMessage.Level == MqttNetLogLevel.Error)
{
Interlocked.Increment(ref _serverErrorsCount);
}
};

_clientLogger.LogMessagePublished += (s, e) =>
{
if (e.TraceMessage.Level == MqttNetLogLevel.Error)
{
Interlocked.Increment(ref _clientErrorsCount);
}
};
}

public async Task<IMqttServer> StartServerAsync(MqttServerOptionsBuilder options)
{
if (_server != null)
{
throw new InvalidOperationException("Server already started.");
}

_server = _mqttFactory.CreateMqttServer(_serverLogger);
await _server.StartAsync(options.WithDefaultEndpointPort(1888).Build());

return _server;
}

public async Task<IMqttClient> ConnectClient(MqttClientOptionsBuilder options)
{
var client = _mqttFactory.CreateMqttClient(_clientLogger);
_clients.Add(client);

await client.ConnectAsync(options.WithTcpServer("localhost", 1888).Build());

return client;
}

public void ThrowIfLogErrors()
{
if (_serverErrorsCount > 0)
{
throw new Exception($"Server had {_serverErrorsCount} errors.");
}

if (_clientErrorsCount > 0)
{
throw new Exception($"Client(s) had {_clientErrorsCount} errors.");
}
}

public void Dispose()
{
ThrowIfLogErrors();

foreach (var mqttClient in _clients)
{
mqttClient.DisconnectAsync().GetAwaiter().GetResult();
mqttClient.Dispose();
}

_server.StopAsync().GetAwaiter().GetResult();
}
}
}

Laden…
Abbrechen
Speichern