Pārlūkot izejas kodu

Merge pull request #343 from JanEggers/ServerShutdown

Server shutdown
release/3.x.x
Christian pirms 6 gadiem
committed by GitHub
vecāks
revīzija
1e2fee0f33
Šim parakstam datu bāzē netika atrasta zināma atslēga GPG atslēgas ID: 4AEE18F83AFDEB23
8 mainītis faili ar 203 papildinājumiem un 42 dzēšanām
  1. +21
    -7
      Source/MQTTnet/Client/MqttClient.cs
  2. +15
    -8
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  3. +3
    -0
      Source/MQTTnet/Properties/AssemblyInfo.cs
  4. +53
    -11
      Source/MQTTnet/Server/MqttClientSession.cs
  5. +0
    -10
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  6. +2
    -2
      Source/MQTTnet/Server/MqttServer.cs
  7. +41
    -4
      Tests/MQTTnet.Core.Tests/MqttClientTests.cs
  8. +68
    -0
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs

+ 21
- 7
Source/MQTTnet/Client/MqttClient.cs Parādīt failu

@@ -25,10 +25,11 @@ namespace MQTTnet.Client

private IMqttClientOptions _options;
private CancellationTokenSource _cancellationTokenSource;
private Task _packetReceiverTask;
private Task _keepAliveMessageSenderTask;
internal Task _packetReceiverTask;
internal Task _keepAliveMessageSenderTask;
private IMqttChannelAdapter _adapter;
private bool _cleanDisconnectInitiated;
private TaskCompletionSource<bool> _disconnectReason;

public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger)
{
@@ -54,6 +55,7 @@ namespace MQTTnet.Client
try
{
_cancellationTokenSource = new CancellationTokenSource();
_disconnectReason = new TaskCompletionSource<bool>();
_options = options;
_packetIdentifierProvider.Reset();
_packetDispatcher.Reset();
@@ -85,8 +87,10 @@ namespace MQTTnet.Client
catch (Exception exception)
{
_logger.Error(exception, "Error while connecting with server.");
await DisconnectInternalAsync(null, exception).ConfigureAwait(false);

if (_disconnectReason.TrySetException(exception))
{
await DisconnectInternalAsync(null, exception).ConfigureAwait(false);
}
throw;
}
}
@@ -104,7 +108,10 @@ namespace MQTTnet.Client
}
finally
{
await DisconnectInternalAsync(null, null).ConfigureAwait(false);
if (_disconnectReason.TrySetCanceled())
{
await DisconnectInternalAsync(null, null).ConfigureAwait(false);
}
}
}

@@ -352,7 +359,10 @@ namespace MQTTnet.Client
_logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets.");
}

await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false);
if (_disconnectReason.TrySetException(exception))
{
await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false);
}
}
finally
{
@@ -395,7 +405,11 @@ namespace MQTTnet.Client
}

_packetDispatcher.Dispatch(exception);
await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false);

if (_disconnectReason.TrySetException(exception))
{
await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false);
}
}
finally
{


+ 15
- 8
Source/MQTTnet/Implementations/MqttTcpChannel.cs Parādīt failu

@@ -87,8 +87,14 @@ namespace MQTTnet.Implementations

public void Dispose()
{
TryDispose(_stream, () => _stream = null);
TryDispose(_socket, () => _socket = null);
Cleanup(ref _stream, (s) => s.Dispose());
Cleanup(ref _socket, (s) => {
if (s.Connected)
{
s.Shutdown(SocketShutdown.Both);
}
s.Dispose();
});
}

private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
@@ -157,11 +163,16 @@ namespace MQTTnet.Implementations
}
}

private static void TryDispose(IDisposable disposable, Action afterDispose)
private static void Cleanup<T>(ref T item, Action<T> handler) where T : class
{
var temp = item;
item = null;
try
{
disposable?.Dispose();
if (temp != null)
{
handler(temp);
}
}
catch (ObjectDisposedException)
{
@@ -169,10 +180,6 @@ namespace MQTTnet.Implementations
catch (NullReferenceException)
{
}
finally
{
afterDispose();
}
}
}
}


+ 3
- 0
Source/MQTTnet/Properties/AssemblyInfo.cs Parādīt failu

@@ -0,0 +1,3 @@
using System.Runtime.CompilerServices;

[assembly:InternalsVisibleTo("MQTTnet.Core.Tests")]

+ 53
- 11
Source/MQTTnet/Server/MqttClientSession.cs Parādīt failu

@@ -29,6 +29,8 @@ namespace MQTTnet.Server
private MqttApplicationMessage _willMessage;
private bool _wasCleanDisconnect;
private IMqttChannelAdapter _adapter;
private Task<bool> _workerTask;
private IDisposable _cleanupHandle;

public MqttClientSession(
string clientId,
@@ -65,7 +67,13 @@ namespace MQTTnet.Server
status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived;
}

public async Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
{
_workerTask = RunInternalAsync(connectPacket, adapter);
return _workerTask;
}

private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
{
if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket));
if (adapter == null) throw new ArgumentNullException(nameof(adapter));
@@ -77,6 +85,10 @@ namespace MQTTnet.Server
adapter.ReadingPacketCompleted += OnAdapterReadingPacketCompleted;

_cancellationTokenSource = new CancellationTokenSource();

//woraround for https://github.com/dotnet/corefx/issues/24430
_cleanupHandle = _cancellationTokenSource.Token.Register(() => Cleanup());
//endworkaround
_wasCleanDisconnect = false;
_willMessage = connectPacket.WillMessage;

@@ -114,24 +126,49 @@ namespace MQTTnet.Server
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId);
}

Stop(MqttClientDisconnectType.NotClean);
Stop(MqttClientDisconnectType.NotClean, true);
}
finally
{
if (_adapter != null)
await Cleanup().ConfigureAwait(false);

_cleanupHandle?.Dispose();
_cleanupHandle = null;

_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
}
}

private async Task Cleanup()
{
try
{
var adapter = _adapter;
if (adapter == null)
{
_adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
_adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
return;
}

_adapter = null;

_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
adapter.Dispose();
}
catch (Exception exception)
{
_logger.Error(exception, exception.Message);
}
}

public void Stop(MqttClientDisconnectType type)
{
Stop(type, false);
}

private void Stop(MqttClientDisconnectType type, bool isInsideSession)
{
try
{
@@ -151,6 +188,11 @@ namespace MQTTnet.Server
}

_willMessage = null;

if (!isInsideSession)
{
_workerTask?.GetAwaiter().GetResult();
}
}
finally
{
@@ -298,18 +340,18 @@ namespace MQTTnet.Server

if (packet is MqttDisconnectPacket)
{
Stop(MqttClientDisconnectType.Clean);
Stop(MqttClientDisconnectType.Clean, true);
return;
}

if (packet is MqttConnectPacket)
{
Stop(MqttClientDisconnectType.NotClean);
Stop(MqttClientDisconnectType.NotClean, true);
return;
}

_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet);
Stop(MqttClientDisconnectType.NotClean);
Stop(MqttClientDisconnectType.NotClean, true);
}

private void EnqueueSubscribedRetainedMessages(ICollection<TopicFilter> topicFilters)
@@ -328,7 +370,7 @@ namespace MQTTnet.Server

if (subscribeResult.CloseConnection)
{
Stop(MqttClientDisconnectType.NotClean);
Stop(MqttClientDisconnectType.NotClean, true);
return;
}



+ 0
- 10
Source/MQTTnet/Server/MqttClientSessionsManager.cs Parādīt failu

@@ -264,16 +264,6 @@ namespace MQTTnet.Server
}
finally
{
try
{
await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
clientAdapter.Dispose();
}
catch (Exception exception)
{
_logger.Error(exception, exception.Message);
}

if (!_options.EnablePersistentSessions)
{
DeleteSession(clientId);


+ 2
- 2
Source/MQTTnet/Server/MqttServer.cs Parādīt failu

@@ -106,14 +106,14 @@ namespace MQTTnet.Server

_cancellationTokenSource.Cancel(false);
_clientSessionsManager.Stop();

foreach (var adapter in _adapters)
{
adapter.ClientAccepted -= OnClientAccepted;
await adapter.StopAsync().ConfigureAwait(false);
}

_clientSessionsManager.Stop();

_logger.Info("Stopped.");
Stopped?.Invoke(this, EventArgs.Empty);
}


+ 41
- 4
Tests/MQTTnet.Core.Tests/MqttClientTests.cs Parādīt failu

@@ -1,10 +1,13 @@
using System;
using System;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Adapter;
using MQTTnet.Client;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Packets;
using MQTTnet.Implementations;
using MQTTnet.Server;

@@ -19,10 +22,10 @@ namespace MQTTnet.Core.Tests
var factory = new MqttFactory();
var client = factory.CreateMqttClient();

var exceptionIsCorrect = false;
Exception ex = null;
client.Disconnected += (s, e) =>
{
exceptionIsCorrect = e.Exception is MqttCommunicationException c && c.InnerException is SocketException;
ex = e.Exception;
};

try
@@ -32,8 +35,42 @@ namespace MQTTnet.Core.Tests
catch
{
}

Assert.IsNotNull(ex);
Assert.IsInstanceOfType(ex, typeof(MqttCommunicationException));
Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException));
}

[TestMethod]
public async Task ClientCleanupOnAuthentificationFails()
{
var channel = new TestMqttCommunicationAdapter();
var channel2 = new TestMqttCommunicationAdapter();
channel.Partner = channel2;
channel2.Partner = channel;

Task.Run(async () => {
var connect = await channel2.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None);
await channel2.SendPacketAsync(new MqttConnAckPacket() { ConnectReturnCode = Protocol.MqttConnectReturnCode.ConnectionRefusedNotAuthorized }, CancellationToken.None);
});


Assert.IsTrue(exceptionIsCorrect);
var fake = new TestMqttCommunicationAdapterFactory(channel);

var client = new MqttClient(fake, new MqttNetLogger());
try
{
await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("any-server").Build());
}
catch (Exception ex)
{
Assert.IsInstanceOfType(ex, typeof(MqttConnectingFailedException));
}

Assert.IsTrue(client._packetReceiverTask == null || client._packetReceiverTask.IsCompleted, "receive loop not completed");
Assert.IsTrue(client._keepAliveMessageSenderTask == null || client._keepAliveMessageSenderTask.IsCompleted, "keepalive loop not completed");
}
}
}

+ 68
- 0
Tests/MQTTnet.Core.Tests/MqttServerTests.cs Parādīt failu

@@ -233,6 +233,74 @@ namespace MQTTnet.Core.Tests
await c1.PublishAsync(message);
}
}
[TestMethod]
public async Task MqttServer_ShutdownDisconnectsClientsGracefully()
{
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger());
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();

var disconnectCalled = 0;

await s.StartAsync(new MqttServerOptions());

var c1 = new MqttFactory().CreateMqttClient();
c1.Disconnected += (sender, args) => disconnectCalled++;

await c1.ConnectAsync(clientOptions);

await Task.Delay(100);

await s.StopAsync();

await Task.Delay(100);

Assert.AreEqual(1, disconnectCalled);
}

[TestMethod]
public async Task MqttServer_HandleCleanDisconnect()
{
MqttNetGlobalLogger.LogMessagePublished += (_, e) =>
{
System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}");
};

var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger());
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());

var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0;

s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();

await s.StartAsync(new MqttServerOptions());

var c1 = new MqttFactory().CreateMqttClient();

await c1.ConnectAsync(clientOptions);

await Task.Delay(100);

await c1.DisconnectAsync();

await Task.Delay(100);

await s.StopAsync();

await Task.Delay(100);

Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
}

[TestMethod]
public async Task MqttServer_RetainedMessagesFlow()


Notiek ielāde…
Atcelt
Saglabāt