Browse Source

Refactor interceptors etc. to support async.

release/3.x.x
Christian Kratky 6 years ago
parent
commit
c0507fcc55
33 changed files with 542 additions and 316 deletions
  1. +2
    -2
      Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs
  2. +1
    -1
      Source/MQTTnet/Formatter/MqttPacketReader.cs
  3. +16
    -10
      Source/MQTTnet/Formatter/MqttPacketWriter.cs
  4. +2
    -2
      Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs
  5. +2
    -2
      Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs
  6. +24
    -10
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  7. +16
    -7
      Source/MQTTnet/Internal/AsyncLock.cs
  8. +2
    -1
      Source/MQTTnet/Server/IMqttClientSession.cs
  9. +0
    -3
      Source/MQTTnet/Server/IMqttServer.cs
  10. +9
    -0
      Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs
  11. +9
    -0
      Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs
  12. +9
    -0
      Source/MQTTnet/Server/IMqttServerConnectionValidator.cs
  13. +4
    -4
      Source/MQTTnet/Server/IMqttServerOptions.cs
  14. +9
    -0
      Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs
  15. +1
    -1
      Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs
  16. +2
    -2
      Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs
  17. +107
    -110
      Source/MQTTnet/Server/MqttClientSession.cs
  18. +4
    -7
      Source/MQTTnet/Server/MqttClientSessionStatus.cs
  19. +46
    -50
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  20. +11
    -6
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  21. +21
    -28
      Source/MQTTnet/Server/MqttRetainedMessagesManager.cs
  22. +3
    -8
      Source/MQTTnet/Server/MqttServer.cs
  23. +31
    -0
      Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs
  24. +31
    -0
      Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs
  25. +5
    -5
      Source/MQTTnet/Server/MqttServerOptions.cs
  26. +21
    -3
      Source/MQTTnet/Server/MqttServerOptionsBuilder.cs
  27. +31
    -0
      Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs
  28. +1
    -1
      Tests/MQTTnet.Core.Tests/AsyncLockTests.cs
  29. +2
    -1
      Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs
  30. +101
    -34
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  31. +5
    -5
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs
  32. +7
    -6
      Tests/MQTTnet.TestApp.NetCore/ServerTest.cs
  33. +7
    -7
      Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs

+ 2
- 2
Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs View File

@@ -30,7 +30,7 @@ namespace MQTTnet.Extensions.ManagedClient
{ {
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));


using (await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
using (await _messagesLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
_messages.Add(applicationMessage); _messages.Add(applicationMessage);
await SaveAsync().ConfigureAwait(false); await SaveAsync().ConfigureAwait(false);
@@ -41,7 +41,7 @@ namespace MQTTnet.Extensions.ManagedClient
{ {
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));


using (await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
using (await _messagesLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{ {
var index = _messages.IndexOf(applicationMessage); var index = _messages.IndexOf(applicationMessage);
if (index == -1) if (index == -1)


+ 1
- 1
Source/MQTTnet/Formatter/MqttPacketReader.cs View File

@@ -63,7 +63,7 @@ namespace MQTTnet.Formatter
{ {
var offset = 0; var offset = 0;
var multiplier = 128; var multiplier = 128;
var value = (initialEncodedByte & 127);
var value = initialEncodedByte & 127;
int encodedByte = initialEncodedByte; int encodedByte = initialEncodedByte;


while ((encodedByte & 128) != 0) while ((encodedByte & 128) != 0)


+ 16
- 10
Source/MQTTnet/Formatter/MqttPacketWriter.cs View File

@@ -15,6 +15,7 @@ namespace MQTTnet.Formatter
{ {
public static int MaxBufferSize = 4096; public static int MaxBufferSize = 4096;


// TODO: Consider using the ArrayPool here together with FreeBuffer.
private byte[] _buffer = new byte[128]; private byte[] _buffer = new byte[128];


private int _offset; private int _offset;
@@ -28,13 +29,18 @@ namespace MQTTnet.Formatter
return (byte)fixedHeader; return (byte)fixedHeader;
} }


public static ArraySegment<byte> EncodeVariableByteInteger(uint value)
public static ArraySegment<byte> EncodeVariableLengthInteger(uint value)
{ {
if (value <= 0)
if (value == 0)
{ {
return new ArraySegment<byte>(new byte[1], 0, 1); return new ArraySegment<byte>(new byte[1], 0, 1);
} }


if (value <= 127)
{
return new ArraySegment<byte>(new[] { (byte)value }, 0, 1);
}

var buffer = new byte[4]; var buffer = new byte[4];
var bufferOffset = 0; var bufferOffset = 0;


@@ -57,7 +63,7 @@ namespace MQTTnet.Formatter


public void WriteVariableLengthInteger(uint value) public void WriteVariableLengthInteger(uint value)
{ {
Write(EncodeVariableByteInteger(value));
Write(EncodeVariableLengthInteger(value));
} }


public void WriteWithLengthPrefix(string value) public void WriteWithLengthPrefix(string value)
@@ -80,7 +86,7 @@ namespace MQTTnet.Formatter
EnsureAdditionalCapacity(1); EnsureAdditionalCapacity(1);


_buffer[_offset] = @byte; _buffer[_offset] = @byte;
IncreasePostition(1);
IncreasePosition(1);
} }


public void Write(ushort value) public void Write(ushort value)
@@ -88,9 +94,9 @@ namespace MQTTnet.Formatter
EnsureAdditionalCapacity(2); EnsureAdditionalCapacity(2);


_buffer[_offset] = (byte)(value >> 8); _buffer[_offset] = (byte)(value >> 8);
IncreasePostition(1);
IncreasePosition(1);
_buffer[_offset] = (byte)value; _buffer[_offset] = (byte)value;
IncreasePostition(1);
IncreasePosition(1);
} }


public void Write(byte[] buffer, int offset, int count) public void Write(byte[] buffer, int offset, int count)
@@ -100,7 +106,7 @@ namespace MQTTnet.Formatter
EnsureAdditionalCapacity(count); EnsureAdditionalCapacity(count);


Array.Copy(buffer, offset, _buffer, _offset, count); Array.Copy(buffer, offset, _buffer, _offset, count);
IncreasePostition(count);
IncreasePosition(count);
} }


public void Write(ArraySegment<byte> buffer) public void Write(ArraySegment<byte> buffer)
@@ -122,9 +128,9 @@ namespace MQTTnet.Formatter
Write(propertyWriter._buffer, 0, propertyWriter.Length); Write(propertyWriter._buffer, 0, propertyWriter.Length);
} }


public void Reset()
public void Reset(int length)
{ {
Length = 5;
Length = length;
} }


public void Seek(int position) public void Seek(int position)
@@ -185,7 +191,7 @@ namespace MQTTnet.Formatter
} }


[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private void IncreasePostition(int length)
private void IncreasePosition(int length)
{ {
_offset += length; _offset += length;




+ 2
- 2
Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs View File

@@ -21,13 +21,13 @@ namespace MQTTnet.Formatter.V3
if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packet == null) throw new ArgumentNullException(nameof(packet));


// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes)
_packetWriter.Reset();
_packetWriter.Reset(5);
_packetWriter.Seek(5); _packetWriter.Seek(5);


var fixedHeader = EncodePacket(packet, _packetWriter); var fixedHeader = EncodePacket(packet, _packetWriter);
var remainingLength = (uint)(_packetWriter.Length - 5); var remainingLength = (uint)(_packetWriter.Length - 5);


var remainingLengthBuffer = MqttPacketWriter.EncodeVariableByteInteger(remainingLength);
var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength);


var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; var headerSize = FixedHeaderSize + remainingLengthBuffer.Count;
var headerOffset = 5 - headerSize; var headerOffset = 5 - headerSize;


+ 2
- 2
Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs View File

@@ -15,13 +15,13 @@ namespace MQTTnet.Formatter.V5
if (packet == null) throw new ArgumentNullException(nameof(packet)); if (packet == null) throw new ArgumentNullException(nameof(packet));


// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes)
_packetWriter.Reset();
_packetWriter.Reset(5);
_packetWriter.Seek(5); _packetWriter.Seek(5);


var fixedHeader = EncodePacket(packet, _packetWriter); var fixedHeader = EncodePacket(packet, _packetWriter);
var remainingLength = (uint)(_packetWriter.Length - 5); var remainingLength = (uint)(_packetWriter.Length - 5);


var remainingLengthBuffer = MqttPacketWriter.EncodeVariableByteInteger(remainingLength);
var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength);


var headerSize = 1 + remainingLengthBuffer.Count; var headerSize = 1 + remainingLengthBuffer.Count;
var headerOffset = 5 - headerSize; var headerOffset = 5 - headerSize;


+ 24
- 10
Source/MQTTnet/Implementations/MqttTcpChannel.cs View File

@@ -52,11 +52,15 @@ namespace MQTTnet.Implementations
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; _socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
} }


// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
#if NET452 || NET461 #if NET452 || NET461
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false);
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false);
#else #else
await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false);
await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false);
#endif #endif
}


SslStream sslStream = null; SslStream sslStream = null;
if (_options.TlsOptions.UseTls) if (_options.TlsOptions.UseTls)
@@ -74,14 +78,23 @@ namespace MQTTnet.Implementations
return Task.FromResult(0); return Task.FromResult(0);
} }


public Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{ {
return _stream.ReadAsync(buffer, offset, count, cancellationToken);
// Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
} }


public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{ {
return _stream.WriteAsync(buffer, offset, count, cancellationToken);
// Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken);
}
} }


public void Dispose() public void Dispose()
@@ -89,10 +102,11 @@ namespace MQTTnet.Implementations
Cleanup(ref _stream, s => s.Dispose()); Cleanup(ref _stream, s => s.Dispose());
Cleanup(ref _socket, s => Cleanup(ref _socket, s =>
{ {
if (s.Connected)
{
s.Shutdown(SocketShutdown.Both);
}
//if (s.Connected)
//{
// s.Shutdown(SocketShutdown.Both);
//}

s.Dispose(); s.Dispose();
}); });
} }


+ 16
- 7
Source/MQTTnet/Internal/AsyncLock.cs View File

@@ -15,14 +15,23 @@ namespace MQTTnet.Internal
_releaser = Task.FromResult((IDisposable)new Releaser(this)); _releaser = Task.FromResult((IDisposable)new Releaser(this));
} }


public Task<IDisposable> LockAsync(CancellationToken cancellationToken)
public Task<IDisposable> WaitAsync()
{ {
var wait = _semaphore.WaitAsync(cancellationToken);
return wait.IsCompleted ?
_releaser :
wait.ContinueWith((_, state) => (IDisposable)state,
_releaser.Result, cancellationToken,
TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
return WaitAsync(CancellationToken.None);
}

public Task<IDisposable> WaitAsync(CancellationToken cancellationToken)
{
var task = _semaphore.WaitAsync(cancellationToken);
if (task.IsCompleted)
{
return _releaser;
}

return task.ContinueWith(
(_, state) => (IDisposable)state,
_releaser.Result,
cancellationToken, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
} }


public void Dispose() public void Dispose()


+ 2
- 1
Source/MQTTnet/Server/IMqttClientSession.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Threading.Tasks;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
@@ -6,6 +7,6 @@ namespace MQTTnet.Server
{ {
string ClientId { get; } string ClientId { get; }


void Stop(MqttClientDisconnectType disconnectType);
Task StopAsync(MqttClientDisconnectType disconnectType);
} }
} }

+ 0
- 3
Source/MQTTnet/Server/IMqttServer.cs View File

@@ -16,11 +16,8 @@ namespace MQTTnet.Server
IMqttServerOptions Options { get; } IMqttServerOptions Options { get; }


[Obsolete("This method is no longer async. Use the not async method.")]
Task<IList<IMqttClientSessionStatus>> GetClientSessionsStatusAsync(); Task<IList<IMqttClientSessionStatus>> GetClientSessionsStatusAsync();


IList<IMqttClientSessionStatus> GetClientSessionsStatus();

IList<MqttApplicationMessage> GetRetainedMessages(); IList<MqttApplicationMessage> GetRetainedMessages();
Task ClearRetainedMessagesAsync(); Task ClearRetainedMessagesAsync();




+ 9
- 0
Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs View File

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerApplicationMessageInterceptor
{
Task InterceptApplicationMessagePublishAsync(MqttApplicationMessageInterceptorContext context);
}
}

+ 9
- 0
Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs View File

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerClientMessageQueueInterceptor
{
Task InterceptClientMessageQueueEnqueueAsync(MqttClientMessageQueueInterceptorContext context);
}
}

+ 9
- 0
Source/MQTTnet/Server/IMqttServerConnectionValidator.cs View File

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerConnectionValidator
{
Task ValidateConnection(MqttConnectionValidatorContext context);
}
}

+ 4
- 4
Source/MQTTnet/Server/IMqttServerOptions.cs View File

@@ -11,10 +11,10 @@ namespace MQTTnet.Server


TimeSpan DefaultCommunicationTimeout { get; } TimeSpan DefaultCommunicationTimeout { get; }


Action<MqttConnectionValidatorContext> ConnectionValidator { get; }
Action<MqttSubscriptionInterceptorContext> SubscriptionInterceptor { get; }
Action<MqttApplicationMessageInterceptorContext> ApplicationMessageInterceptor { get; }
Action<MqttClientMessageQueueInterceptorContext> ClientMessageQueueInterceptor { get; }
IMqttServerConnectionValidator ConnectionValidator { get; }
IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; }
IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; }
IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; }


MqttServerTcpEndpointOptions DefaultEndpointOptions { get; } MqttServerTcpEndpointOptions DefaultEndpointOptions { get; }
MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; } MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; }


+ 9
- 0
Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs View File

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerSubscriptionInterceptor
{
Task InterceptSubscriptionAsync(MqttSubscriptionInterceptorContext context);
}
}

+ 1
- 1
Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs View File

@@ -81,7 +81,7 @@ namespace MQTTnet.Server
if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds >= keepAlivePeriod * 1.5D) if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds >= keepAlivePeriod * 1.5D)
{ {
_logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId); _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId);
_clientSession.Stop(MqttClientDisconnectType.NotClean);
await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);


return; return;
} }


+ 2
- 2
Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs View File

@@ -133,7 +133,7 @@ namespace MQTTnet.Server
return; return;
} }


adapter.SendPacketAsync(packet, cancellationToken).GetAwaiter().GetResult();
await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false);


_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId);
} }
@@ -167,7 +167,7 @@ namespace MQTTnet.Server


if (!cancellationToken.IsCancellationRequested) if (!cancellationToken.IsCancellationRequested)
{ {
_clientSession.Stop(MqttClientDisconnectType.NotClean);
await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
} }
} }
} }


+ 107
- 110
Source/MQTTnet/Server/MqttClientSession.cs View File

@@ -79,22 +79,20 @@ namespace MQTTnet.Server
return _workerTask; return _workerTask;
} }


public void Stop(MqttClientDisconnectType type)
public Task StopAsync(MqttClientDisconnectType type)
{ {
Stop(type, false);
return StopAsync(type, false);
} }


public Task SubscribeAsync(IList<TopicFilter> topicFilters)
public async Task SubscribeAsync(IList<TopicFilter> topicFilters)
{ {
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));


var packet = new MqttSubscribePacket(); var packet = new MqttSubscribePacket();
packet.TopicFilters.AddRange(topicFilters); packet.TopicFilters.AddRange(topicFilters);


_subscriptionsManager.Subscribe(packet);

EnqueueSubscribedRetainedMessages(topicFilters);
return Task.FromResult(0);
await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false);
await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false);
} }


public Task UnsubscribeAsync(IList<string> topicFilters) public Task UnsubscribeAsync(IList<string> topicFilters)
@@ -122,7 +120,7 @@ namespace MQTTnet.Server
_cancellationTokenSource?.Dispose(); _cancellationTokenSource?.Dispose();
_cancellationTokenSource = null; _cancellationTokenSource = null;
} }
private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
{ {
if (adapter == null) throw new ArgumentNullException(nameof(adapter)); if (adapter == null) throw new ArgumentNullException(nameof(adapter));
@@ -131,7 +129,7 @@ namespace MQTTnet.Server
{ {
if (_cancellationTokenSource != null) if (_cancellationTokenSource != null)
{ {
Stop(MqttClientDisconnectType.Clean, true);
await StopAsync(MqttClientDisconnectType.Clean, true).ConfigureAwait(false);
} }


adapter.ReadingPacketStarted += OnAdapterReadingPacketStarted; adapter.ReadingPacketStarted += OnAdapterReadingPacketStarted;
@@ -139,15 +137,15 @@ namespace MQTTnet.Server


_cancellationTokenSource = new CancellationTokenSource(); _cancellationTokenSource = new CancellationTokenSource();


//workaround for https://github.com/dotnet/corefx/issues/24430
#pragma warning disable 4014
_cleanupHandle = _cancellationTokenSource.Token.Register(async () =>
{
await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
TryDisposeAdapter(adapter);
});
#pragma warning restore 4014
//end workaround
// //workaround for https://github.com/dotnet/corefx/issues/24430
//#pragma warning disable 4014
// _cleanupHandle = _cancellationTokenSource.Token.Register(async () =>
// {
// await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
// TryDisposeAdapter(adapter);
// });
//#pragma warning restore 4014
// //end workaround


_wasCleanDisconnect = false; _wasCleanDisconnect = false;
_willMessage = connectPacket.WillMessage; _willMessage = connectPacket.WillMessage;
@@ -165,7 +163,7 @@ namespace MQTTnet.Server
if (packet != null) if (packet != null)
{ {
_keepAliveMonitor.PacketReceived(packet); _keepAliveMonitor.PacketReceived(packet);
ProcessReceivedPacket(adapter, packet, _cancellationTokenSource.Token);
await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false);
} }
} }
} }
@@ -190,51 +188,55 @@ namespace MQTTnet.Server
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId);
} }


Stop(MqttClientDisconnectType.NotClean, true);
await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
} }
finally finally
{ {
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;

_adapterEndpoint = null; _adapterEndpoint = null;
_adapterProtocolVersion = null; _adapterProtocolVersion = null;


// Uncomment as soon as the workaround above is no longer needed. // Uncomment as soon as the workaround above is no longer needed.
// Also called in outer scope!
//await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false); //await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
//TryDisposeAdapter(adapter); //TryDisposeAdapter(adapter);


_cleanupHandle?.Dispose(); _cleanupHandle?.Dispose();
_cleanupHandle = null; _cleanupHandle = null;
_cancellationTokenSource?.Cancel(false); _cancellationTokenSource?.Cancel(false);
_cancellationTokenSource?.Dispose(); _cancellationTokenSource?.Dispose();
_cancellationTokenSource = null; _cancellationTokenSource = null;
} }
} }


private void TryDisposeAdapter(IMqttChannelAdapter adapter)
{
if (adapter == null)
{
return;
}
try
{
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
adapter.Dispose();
}
catch (Exception exception)
{
_logger.Error(exception, exception.Message);
}
finally
{
adapter.Dispose();
}
}
private void Stop(MqttClientDisconnectType type, bool isInsideSession)
////private void TryDisposeAdapter(IMqttChannelAdapter adapter)
////{
//// if (adapter == null)
//// {
//// return;
//// }
//// try
//// {
//// adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
//// adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
//// adapter.Dispose();
//// }
//// catch (Exception exception)
//// {
//// _logger.Error(exception, exception.Message);
//// }
//// finally
//// {
//// adapter.Dispose();
//// }
////}
private async Task StopAsync(MqttClientDisconnectType type, bool isInsideSession)
{ {
try try
{ {
@@ -257,7 +259,10 @@ namespace MQTTnet.Server


if (!isInsideSession) if (!isInsideSession)
{ {
_workerTask?.GetAwaiter().GetResult();
if (_workerTask != null)
{
await _workerTask.ConfigureAwait(false);
}
} }
} }
finally finally
@@ -267,7 +272,7 @@ namespace MQTTnet.Server
} }
} }


public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage)
public async Task EnqueueApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage)
{ {
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));


@@ -278,10 +283,10 @@ namespace MQTTnet.Server
} }


var publishPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage); var publishPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage);
// Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9]. // Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9].
publishPacket.Retain = isRetainedApplicationMessage; publishPacket.Retain = isRetainedApplicationMessage;
if (publishPacket.QualityOfServiceLevel > 0) if (publishPacket.QualityOfServiceLevel > 0)
{ {
publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier();
@@ -294,7 +299,10 @@ namespace MQTTnet.Server
ClientId, ClientId,
applicationMessage); applicationMessage);


_options.ClientMessageQueueInterceptor?.Invoke(context);
if (_options.ClientMessageQueueInterceptor != null)
{
await _options.ClientMessageQueueInterceptor.InterceptClientMessageQueueEnqueueAsync(context).ConfigureAwait(false);
}


if (!context.AcceptEnqueue || context.ApplicationMessage == null) if (!context.AcceptEnqueue || context.ApplicationMessage == null)
{ {
@@ -309,35 +317,33 @@ namespace MQTTnet.Server
_pendingPacketsQueue.Enqueue(publishPacket); _pendingPacketsQueue.Enqueue(publishPacket);
} }


private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter)
{
if (adapter == null)
{
return;
}
try
{
await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while disconnecting channel adapter.");
}
}
private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
//private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter)
//{
// if (adapter == null)
// {
// return;
// }
// try
// {
// await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
// }
// catch (Exception exception)
// {
// _logger.Error(exception, "Error while disconnecting channel adapter.");
// }
//}
private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
{ {
if (packet is MqttPublishPacket publishPacket) if (packet is MqttPublishPacket publishPacket)
{ {
HandleIncomingPublishPacket(adapter, publishPacket, cancellationToken);
return;
return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken);
} }


if (packet is MqttPingReqPacket) if (packet is MqttPingReqPacket)
{ {
adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken).GetAwaiter().GetResult();
return;
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
} }


if (packet is MqttPubRelPacket pubRelPacket) if (packet is MqttPubRelPacket pubRelPacket)
@@ -348,8 +354,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubCompReasonCode.Success ReasonCode = MqttPubCompReasonCode.Success
}; };


adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult();
return;
return adapter.SendPacketAsync(responsePacket, cancellationToken);
} }


if (packet is MqttPubRecPacket pubRecPacket) if (packet is MqttPubRecPacket pubRecPacket)
@@ -360,91 +365,83 @@ namespace MQTTnet.Server
ReasonCode = MqttPubRelReasonCode.Success ReasonCode = MqttPubRelReasonCode.Success
}; };


adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult();
return;
return adapter.SendPacketAsync(responsePacket, cancellationToken);
} }


if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
{ {
return;
return Task.FromResult(0);
} }


if (packet is MqttSubscribePacket subscribePacket) if (packet is MqttSubscribePacket subscribePacket)
{ {
HandleIncomingSubscribePacket(adapter, subscribePacket, cancellationToken);
return;
return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken);
} }


if (packet is MqttUnsubscribePacket unsubscribePacket) if (packet is MqttUnsubscribePacket unsubscribePacket)
{ {
HandleIncomingUnsubscribePacket(adapter, unsubscribePacket, cancellationToken);
return;
return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken);
} }


if (packet is MqttDisconnectPacket) if (packet is MqttDisconnectPacket)
{ {
Stop(MqttClientDisconnectType.Clean, true);
return;
return StopAsync(MqttClientDisconnectType.Clean, true);
} }


if (packet is MqttConnectPacket) if (packet is MqttConnectPacket)
{ {
Stop(MqttClientDisconnectType.NotClean, true);
return;
return StopAsync(MqttClientDisconnectType.NotClean, true);
} }


_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); _logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet);


Stop(MqttClientDisconnectType.NotClean, true);
return StopAsync(MqttClientDisconnectType.NotClean, true);
} }


private void EnqueueSubscribedRetainedMessages(ICollection<TopicFilter> topicFilters)
private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
{ {
var retainedMessages = _retainedMessagesManager.GetSubscribedMessages(topicFilters);
var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false);
foreach (var applicationMessage in retainedMessages) foreach (var applicationMessage in retainedMessages)
{ {
EnqueueApplicationMessage(null, applicationMessage, true);
await EnqueueApplicationMessageAsync(null, applicationMessage, true).ConfigureAwait(false);
} }
} }


private void HandleIncomingSubscribePacket(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
{ {
var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket);
adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).GetAwaiter().GetResult();
var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false);
await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false);


if (subscribeResult.CloseConnection) if (subscribeResult.CloseConnection)
{ {
Stop(MqttClientDisconnectType.NotClean, true);
return;
await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
} }


EnqueueSubscribedRetainedMessages(subscribePacket.TopicFilters);
await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false);
} }


private void HandleIncomingUnsubscribePacket(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
{ {
var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket);
adapter.SendPacketAsync(unsubscribeResult, cancellationToken).GetAwaiter().GetResult();
return adapter.SendPacketAsync(unsubscribeResult, cancellationToken);
} }


private void HandleIncomingPublishPacket(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{ {
switch (publishPacket.QualityOfServiceLevel) switch (publishPacket.QualityOfServiceLevel)
{ {
case MqttQualityOfServiceLevel.AtMostOnce: case MqttQualityOfServiceLevel.AtMostOnce:
{ {
HandleIncomingPublishPacketWithQoS0(publishPacket); HandleIncomingPublishPacketWithQoS0(publishPacket);
break;
return Task.FromResult(0);
} }
case MqttQualityOfServiceLevel.AtLeastOnce: case MqttQualityOfServiceLevel.AtLeastOnce:
{ {
HandleIncomingPublishPacketWithQoS1(adapter, publishPacket, cancellationToken);
break;
return HandleIncomingPublishPacketWithQoS1Async(adapter, publishPacket, cancellationToken);
} }
case MqttQualityOfServiceLevel.ExactlyOnce: case MqttQualityOfServiceLevel.ExactlyOnce:
{ {
HandleIncomingPublishPacketWithQoS2(adapter, publishPacket, cancellationToken);
break;
return HandleIncomingPublishPacketWithQoS2Async(adapter, publishPacket, cancellationToken);
} }
default: default:
{ {
@@ -456,17 +453,17 @@ namespace MQTTnet.Server
private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket) private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket)
{ {
_sessionsManager.EnqueueApplicationMessage( _sessionsManager.EnqueueApplicationMessage(
this,
this,
_channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket));
} }


private void HandleIncomingPublishPacketWithQoS1(
private Task HandleIncomingPublishPacketWithQoS1Async(
IMqttChannelAdapter adapter, IMqttChannelAdapter adapter,
MqttPublishPacket publishPacket, MqttPublishPacket publishPacket,
CancellationToken cancellationToken) CancellationToken cancellationToken)
{ {
_sessionsManager.EnqueueApplicationMessage( _sessionsManager.EnqueueApplicationMessage(
this,
this,
_channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket));


var response = new MqttPubAckPacket var response = new MqttPubAckPacket
@@ -475,10 +472,10 @@ namespace MQTTnet.Server
ReasonCode = MqttPubAckReasonCode.Success ReasonCode = MqttPubAckReasonCode.Success
}; };


adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult();
return adapter.SendPacketAsync(response, cancellationToken);
} }


private void HandleIncomingPublishPacketWithQoS2(
private Task HandleIncomingPublishPacketWithQoS2Async(
IMqttChannelAdapter adapter, IMqttChannelAdapter adapter,
MqttPublishPacket publishPacket, MqttPublishPacket publishPacket,
CancellationToken cancellationToken) CancellationToken cancellationToken)
@@ -492,7 +489,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubRecReasonCode.Success ReasonCode = MqttPubRecReasonCode.Success
}; };


adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult();
return adapter.SendPacketAsync(response, cancellationToken);
} }


private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) private void OnAdapterReadingPacketCompleted(object sender, EventArgs e)


+ 4
- 7
Source/MQTTnet/Server/MqttClientSessionStatus.cs View File

@@ -25,22 +25,19 @@ namespace MQTTnet.Server


public Task DisconnectAsync() public Task DisconnectAsync()
{ {
_session.Stop(MqttClientDisconnectType.NotClean);
return Task.FromResult(0);
return _session.StopAsync(MqttClientDisconnectType.NotClean);
} }


public Task DeleteSessionAsync()
public async Task DeleteSessionAsync()
{ {
try try
{ {
_session.Stop(MqttClientDisconnectType.NotClean);
await _session.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
} }
finally finally
{ {
_sessionsManager.DeleteSession(ClientId);
await _sessionsManager.DeleteSessionAsync(ClientId).ConfigureAwait(false);
} }

return Task.FromResult(0);
} }


public Task ClearPendingApplicationMessagesAsync() public Task ClearPendingApplicationMessagesAsync()


+ 46
- 50
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -1,12 +1,11 @@
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Adapter; using MQTTnet.Adapter;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets; using MQTTnet.Packets;
using MQTTnet.Protocol; using MQTTnet.Protocol;


@@ -16,9 +15,7 @@ namespace MQTTnet.Server
{ {
private readonly BlockingCollection<MqttEnqueuedApplicationMessage> _messageQueue = new BlockingCollection<MqttEnqueuedApplicationMessage>(); private readonly BlockingCollection<MqttEnqueuedApplicationMessage> _messageQueue = new BlockingCollection<MqttEnqueuedApplicationMessage>();


/// <summary>
/// manual locking dictionaries is faster than using concurrent dictionary
/// </summary>
private readonly AsyncLock _sessionsLock = new AsyncLock();
private readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>(); private readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>();


private readonly CancellationToken _cancellationToken; private readonly CancellationToken _cancellationToken;
@@ -47,16 +44,16 @@ namespace MQTTnet.Server


public void Start() public void Start()
{ {
Task.Factory.StartNew(() => TryProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
Task.Factory.StartNew(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
} }


public void Stop()
public async Task StopAsync()
{ {
lock (_sessions)
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{ {
foreach (var session in _sessions) foreach (var session in _sessions)
{ {
session.Value.Stop(MqttClientDisconnectType.NotClean);
await session.Value.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
} }


_sessions.Clear(); _sessions.Clear();
@@ -68,18 +65,21 @@ namespace MQTTnet.Server
return Task.Run(() => RunSessionAsync(clientAdapter, _cancellationToken), _cancellationToken); return Task.Run(() => RunSessionAsync(clientAdapter, _cancellationToken), _cancellationToken);
} }


public IList<IMqttClientSessionStatus> GetClientStatus()
public async Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
{ {
var result = new List<IMqttClientSessionStatus>(); var result = new List<IMqttClientSessionStatus>();


foreach (var session in GetSessions())
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{ {
var status = new MqttClientSessionStatus(this, session);
session.FillStatus(status);
foreach (var session in _sessions.Values)
{
var status = new MqttClientSessionStatus(this, session);
session.FillStatus(status);


result.Add(status);
result.Add(status);
}
} }
return result; return result;
} }


@@ -122,9 +122,9 @@ namespace MQTTnet.Server
} }
} }


public void DeleteSession(string clientId)
public async Task DeleteSessionAsync(string clientId)
{ {
lock (_sessions)
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{ {
_sessions.Remove(clientId); _sessions.Remove(clientId);
} }
@@ -137,13 +137,13 @@ namespace MQTTnet.Server
_messageQueue?.Dispose(); _messageQueue?.Dispose();
} }


private void TryProcessQueuedApplicationMessages(CancellationToken cancellationToken)
private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken)
{ {
while (!cancellationToken.IsCancellationRequested) while (!cancellationToken.IsCancellationRequested)
{ {
try try
{ {
TryProcessNextQueuedApplicationMessage(cancellationToken);
await TryProcessNextQueuedApplicationMessageAsync(cancellationToken).ConfigureAwait(false);
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
@@ -155,7 +155,7 @@ namespace MQTTnet.Server
} }
} }


private void TryProcessNextQueuedApplicationMessage(CancellationToken cancellationToken)
private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken)
{ {
try try
{ {
@@ -164,12 +164,12 @@ namespace MQTTnet.Server
var sender = enqueuedApplicationMessage.Sender; var sender = enqueuedApplicationMessage.Sender;
var applicationMessage = enqueuedApplicationMessage.ApplicationMessage; var applicationMessage = enqueuedApplicationMessage.ApplicationMessage;


var interceptorContext = InterceptApplicationMessage(sender, applicationMessage);
var interceptorContext = await InterceptApplicationMessageAsync(sender, applicationMessage).ConfigureAwait(false);
if (interceptorContext != null) if (interceptorContext != null)
{ {
if (interceptorContext.CloseConnection) if (interceptorContext.CloseConnection)
{ {
enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean);
await enqueuedApplicationMessage.Sender.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
} }


if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish)
@@ -184,12 +184,18 @@ namespace MQTTnet.Server


if (applicationMessage.Retain) if (applicationMessage.Retain)
{ {
_retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult();
await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
} }


foreach (var clientSession in GetSessions())
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{ {
clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, enqueuedApplicationMessage.ApplicationMessage, false);
foreach (var clientSession in _sessions.Values)
{
await clientSession.EnqueueApplicationMessageAsync(
enqueuedApplicationMessage.Sender,
enqueuedApplicationMessage.ApplicationMessage,
false).ConfigureAwait(false);
}
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)
@@ -201,35 +207,22 @@ namespace MQTTnet.Server
} }
} }


private List<MqttClientSession> GetSessions()
{
lock (_sessions)
{
return _sessions.Values.ToList();
}
}

private async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) private async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
{ {
var clientId = string.Empty; var clientId = string.Empty;
try try
{ {
// TODO: Catch cancel exception here if the first packet was not received and log properly.
var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
if (firstPacket == null)
{
return;
}

if (!(firstPacket is MqttConnectPacket connectPacket)) if (!(firstPacket is MqttConnectPacket connectPacket))
{ {
throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1].");
_logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", clientAdapter.Endpoint);
return;
} }


clientId = connectPacket.ClientId; clientId = connectPacket.ClientId;


var connectReturnCode = ValidateConnection(connectPacket, clientAdapter);
var connectReturnCode = await ValidateConnectionAsync(connectPacket, clientAdapter).ConfigureAwait(false);
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{ {
await clientAdapter.SendPacketAsync( await clientAdapter.SendPacketAsync(
@@ -242,7 +235,7 @@ namespace MQTTnet.Server
return; return;
} }


var result = PrepareClientSession(connectPacket);
var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);


await clientAdapter.SendPacketAsync( await clientAdapter.SendPacketAsync(
new MqttConnAckPacket new MqttConnAckPacket
@@ -267,14 +260,17 @@ namespace MQTTnet.Server
} }
finally finally
{ {
await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
clientAdapter.Dispose();

if (!_options.EnablePersistentSessions) if (!_options.EnablePersistentSessions)
{ {
DeleteSession(clientId);
await DeleteSessionAsync(clientId).ConfigureAwait(false);
} }
} }
} }


private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter)
private async Task<MqttConnectReturnCode> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter)
{ {
if (_options.ConnectionValidator == null) if (_options.ConnectionValidator == null)
{ {
@@ -288,13 +284,13 @@ namespace MQTTnet.Server
connectPacket.WillMessage, connectPacket.WillMessage,
clientAdapter.Endpoint); clientAdapter.Endpoint);


_options.ConnectionValidator(context);
await _options.ConnectionValidator.ValidateConnection(context).ConfigureAwait(false);
return context.ReturnCode; return context.ReturnCode;
} }


private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket)
private async Task<PrepareClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
{ {
lock (_sessions)
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{ {
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
if (isSessionPresent) if (isSessionPresent)
@@ -303,7 +299,7 @@ namespace MQTTnet.Server
{ {
_sessions.Remove(connectPacket.ClientId); _sessions.Remove(connectPacket.ClientId);


clientSession.Stop(MqttClientDisconnectType.Clean);
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);
clientSession.Dispose(); clientSession.Dispose();
clientSession = null; clientSession = null;


@@ -330,7 +326,7 @@ namespace MQTTnet.Server
} }
} }


private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage)
private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientSession sender, MqttApplicationMessage applicationMessage)
{ {
var interceptor = _options.ApplicationMessageInterceptor; var interceptor = _options.ApplicationMessageInterceptor;
if (interceptor == null) if (interceptor == null)
@@ -339,7 +335,7 @@ namespace MQTTnet.Server
} }


var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage); var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage);
interceptor(interceptorContext);
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
return interceptorContext; return interceptorContext;
} }
} }

+ 11
- 6
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks;
using MQTTnet.Packets; using MQTTnet.Packets;
using MQTTnet.Protocol; using MQTTnet.Protocol;


@@ -20,7 +21,7 @@ namespace MQTTnet.Server
_eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher));
} }


public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket)
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket)
{ {
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));


@@ -36,7 +37,7 @@ namespace MQTTnet.Server


foreach (var topicFilter in subscribePacket.TopicFilters) foreach (var topicFilter in subscribePacket.TopicFilters)
{ {
var interceptorContext = InterceptSubscribe(topicFilter);
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
if (!interceptorContext.AcceptSubscription) if (!interceptorContext.AcceptSubscription)
{ {
result.ResponsePacket.ReturnCodes.Add(MqttSubscribeReturnCode.Failure); result.ResponsePacket.ReturnCodes.Add(MqttSubscribeReturnCode.Failure);
@@ -146,11 +147,15 @@ namespace MQTTnet.Server
} }
} }


private MqttSubscriptionInterceptorContext InterceptSubscribe(TopicFilter topicFilter)
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter)
{ {
var interceptorContext = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
_options.SubscriptionInterceptor?.Invoke(interceptorContext);
return interceptorContext;
var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
if (_options.SubscriptionInterceptor != null)
{
await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false);
}

return context;
} }


private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels) private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels)


+ 21
- 28
Source/MQTTnet/Server/MqttRetainedMessagesManager.cs View File

@@ -3,11 +3,14 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Diagnostics; using MQTTnet.Diagnostics;
using MQTTnet.Internal;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
public class MqttRetainedMessagesManager public class MqttRetainedMessagesManager
{ {
private readonly byte[] _emptyArray = new byte[0];
private readonly AsyncLock _messagesLock = new AsyncLock();
private readonly Dictionary<string, MqttApplicationMessage> _messages = new Dictionary<string, MqttApplicationMessage>(); private readonly Dictionary<string, MqttApplicationMessage> _messages = new Dictionary<string, MqttApplicationMessage>();


private readonly IMqttNetChildLogger _logger; private readonly IMqttNetChildLogger _logger;
@@ -31,7 +34,7 @@ namespace MQTTnet.Server
{ {
var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false);


lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{ {
_messages.Clear(); _messages.Clear();
foreach (var retainedMessage in retainedMessages) foreach (var retainedMessage in retainedMessages)
@@ -52,8 +55,7 @@ namespace MQTTnet.Server


try try
{ {
List<MqttApplicationMessage> messagesForSave = null;
lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{ {
var saveIsRequired = false; var saveIsRequired = false;


@@ -71,7 +73,7 @@ namespace MQTTnet.Server
} }
else else
{ {
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0]))
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? _emptyArray))
{ {
_messages[applicationMessage.Topic] = applicationMessage; _messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true; saveIsRequired = true;
@@ -83,20 +85,13 @@ namespace MQTTnet.Server


if (saveIsRequired) if (saveIsRequired)
{ {
messagesForSave = new List<MqttApplicationMessage>(_messages.Values);
if (_options.Storage != null)
{
var messagesForSave = new List<MqttApplicationMessage>(_messages.Values);
await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false);
}
} }
} }

if (messagesForSave == null)
{
_logger.Verbose("Skipped saving retained messages because no changes were detected.");
return;
}

if (_options.Storage != null)
{
await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false);
}
} }
catch (Exception exception) catch (Exception exception)
{ {
@@ -104,11 +99,11 @@ namespace MQTTnet.Server
} }
} }


public IList<MqttApplicationMessage> GetSubscribedMessages(ICollection<TopicFilter> topicFilters)
public async Task<List<MqttApplicationMessage>> GetSubscribedMessagesAsync(ICollection<TopicFilter> topicFilters)
{ {
var retainedMessages = new List<MqttApplicationMessage>(); var retainedMessages = new List<MqttApplicationMessage>();


lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{ {
foreach (var retainedMessage in _messages.Values) foreach (var retainedMessage in _messages.Values)
{ {
@@ -128,27 +123,25 @@ namespace MQTTnet.Server
return retainedMessages; return retainedMessages;
} }


public IList<MqttApplicationMessage> GetMessages()
public async Task<List<MqttApplicationMessage>> GetMessagesAsync()
{ {
lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{ {
return _messages.Values.ToList(); return _messages.Values.ToList();
} }
} }


public Task ClearMessagesAsync()
public async Task ClearMessagesAsync()
{ {
lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{ {
_messages.Clear(); _messages.Clear();
}


if (_options.Storage != null)
{
return _options.Storage.SaveRetainedMessagesAsync(new List<MqttApplicationMessage>());
if (_options.Storage != null)
{
await _options.Storage.SaveRetainedMessagesAsync(new List<MqttApplicationMessage>()).ConfigureAwait(false);
}
} }

return Task.FromResult((object)null);
} }
} }
} }

+ 3
- 8
Source/MQTTnet/Server/MqttServer.cs View File

@@ -48,17 +48,12 @@ namespace MQTTnet.Server


public Task<IList<IMqttClientSessionStatus>> GetClientSessionsStatusAsync() public Task<IList<IMqttClientSessionStatus>> GetClientSessionsStatusAsync()
{ {
return Task.FromResult(_clientSessionsManager.GetClientStatus());
}

public IList<IMqttClientSessionStatus> GetClientSessionsStatus()
{
return _clientSessionsManager.GetClientStatus();
return _clientSessionsManager.GetClientStatusAsync();
} }


public IList<MqttApplicationMessage> GetRetainedMessages() public IList<MqttApplicationMessage> GetRetainedMessages()
{ {
return _retainedMessagesManager.GetMessages();
return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult();
} }


public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters) public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
@@ -123,7 +118,7 @@ namespace MQTTnet.Server


_cancellationTokenSource.Cancel(false); _cancellationTokenSource.Cancel(false);
_clientSessionsManager.Stop();
_clientSessionsManager.StopAsync().ConfigureAwait(false);


foreach (var adapter in _adapters) foreach (var adapter in _adapters)
{ {


+ 31
- 0
Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttServerApplicationMessageInterceptorDelegate : IMqttServerApplicationMessageInterceptor
{
private readonly Func<MqttApplicationMessageInterceptorContext, Task> _callback;

public MqttServerApplicationMessageInterceptorDelegate(Action<MqttApplicationMessageInterceptorContext> callback)
{
if (callback == null) throw new ArgumentNullException(nameof(callback));

_callback = context =>
{
callback(context);
return Task.FromResult(0);
};
}

public MqttServerApplicationMessageInterceptorDelegate(Func<MqttApplicationMessageInterceptorContext, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

public Task InterceptApplicationMessagePublishAsync(MqttApplicationMessageInterceptorContext context)
{
return _callback(context);
}
}
}

+ 31
- 0
Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttServerConnectionValidatorDelegate : IMqttServerConnectionValidator
{
private readonly Func<MqttConnectionValidatorContext, Task> _callback;

public MqttServerConnectionValidatorDelegate(Action<MqttConnectionValidatorContext> callback)
{
if (callback == null) throw new ArgumentNullException(nameof(callback));

_callback = context =>
{
callback(context);
return Task.FromResult(0);
};
}

public MqttServerConnectionValidatorDelegate(Func<MqttConnectionValidatorContext, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

public Task ValidateConnection(MqttConnectionValidatorContext context)
{
return _callback(context);
}
}
}

+ 5
- 5
Source/MQTTnet/Server/MqttServerOptions.cs View File

@@ -16,13 +16,13 @@ namespace MQTTnet.Server


public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15); public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15);


public Action<MqttConnectionValidatorContext> ConnectionValidator { get; set; }
public IMqttServerConnectionValidator ConnectionValidator { get; set; }


public Action<MqttApplicationMessageInterceptorContext> ApplicationMessageInterceptor { get; set; }
public IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; set; }
public IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; set; }


public Action<MqttClientMessageQueueInterceptorContext> ClientMessageQueueInterceptor { get; set; }

public Action<MqttSubscriptionInterceptorContext> SubscriptionInterceptor { get; set; }
public IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; set; }


public IMqttServerStorage Storage { get; set; } public IMqttServerStorage Storage { get; set; }
} }


+ 21
- 3
Source/MQTTnet/Server/MqttServerOptionsBuilder.cs View File

@@ -99,24 +99,42 @@ namespace MQTTnet.Server
return this; return this;
} }


public MqttServerOptionsBuilder WithConnectionValidator(Action<MqttConnectionValidatorContext> value)
public MqttServerOptionsBuilder WithConnectionValidator(IMqttServerConnectionValidator value)
{ {
_options.ConnectionValidator = value; _options.ConnectionValidator = value;
return this; return this;
} }


public MqttServerOptionsBuilder WithApplicationMessageInterceptor(Action<MqttApplicationMessageInterceptorContext> value)
public MqttServerOptionsBuilder WithConnectionValidator(Action<MqttConnectionValidatorContext> value)
{
_options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(value);
return this;
}

public MqttServerOptionsBuilder WithApplicationMessageInterceptor(IMqttServerApplicationMessageInterceptor value)
{ {
_options.ApplicationMessageInterceptor = value; _options.ApplicationMessageInterceptor = value;
return this; return this;
} }


public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action<MqttSubscriptionInterceptorContext> value)
public MqttServerOptionsBuilder WithApplicationMessageInterceptor(Action<MqttApplicationMessageInterceptorContext> value)
{
_options.ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(value);
return this;
}

public MqttServerOptionsBuilder WithSubscriptionInterceptor(IMqttServerSubscriptionInterceptor value)
{ {
_options.SubscriptionInterceptor = value; _options.SubscriptionInterceptor = value;
return this; return this;
} }


public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action<MqttSubscriptionInterceptorContext> value)
{
_options.SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(value);
return this;
}

public MqttServerOptionsBuilder WithPersistentSessions() public MqttServerOptionsBuilder WithPersistentSessions()
{ {
_options.EnablePersistentSessions = true; _options.EnablePersistentSessions = true;


+ 31
- 0
Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttServerSubscriptionInterceptorDelegate : IMqttServerSubscriptionInterceptor
{
private readonly Func<MqttSubscriptionInterceptorContext, Task> _callback;

public MqttServerSubscriptionInterceptorDelegate(Action<MqttSubscriptionInterceptorContext> callback)
{
if (callback == null) throw new ArgumentNullException(nameof(callback));

_callback = context =>
{
callback(context);
return Task.FromResult(0);
};
}

public MqttServerSubscriptionInterceptorDelegate(Func<MqttSubscriptionInterceptorContext, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

public Task InterceptSubscriptionAsync(MqttSubscriptionInterceptorContext context)
{
return _callback(context);
}
}
}

+ 1
- 1
Tests/MQTTnet.Core.Tests/AsyncLockTests.cs View File

@@ -21,7 +21,7 @@ namespace MQTTnet.Tests
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
threads[i] = Task.Run(async () => threads[i] = Task.Run(async () =>
{ {
using (var releaser = await @lock.LockAsync(CancellationToken.None))
using (var releaser = await @lock.WaitAsync(CancellationToken.None))
{ {
var localI = globalI; var localI = globalI;
await Task.Delay(10); // Increase the chance for wrong data. await Task.Delay(10); // Increase the chance for wrong data.


+ 2
- 1
Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs View File

@@ -82,9 +82,10 @@ namespace MQTTnet.Tests
throw new NotSupportedException(); throw new NotSupportedException();
} }


public void Stop(MqttClientDisconnectType disconnectType)
public Task StopAsync(MqttClientDisconnectType disconnectType)
{ {
StopCalledCount++; StopCalledCount++;
return Task.FromResult(0);
} }


public Task SubscribeAsync(IList<TopicFilter> topicFilters) public Task SubscribeAsync(IList<TopicFilter> topicFilters)


+ 101
- 34
Tests/MQTTnet.Core.Tests/MqttServerTests.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net.Sockets;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -52,7 +53,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_WillMessage()
public async Task MqttServer_Will_Message()
{ {
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -84,7 +85,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_SubscribeUnsubscribe()
public async Task MqttServer_Subscribe_Unsubscribe()
{ {
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -230,7 +231,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_SessionTakeover()
public async Task MqttServer_Session_Takeover()
{ {
var server = new MqttFactory().CreateMqttServer(); var server = new MqttFactory().CreateMqttServer();
try try
@@ -299,38 +300,43 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_HandleCleanDisconnect()
public async Task MqttServer_Handle_Clean_Disconnect()
{ {
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger());
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0;
var s = new MqttFactory().CreateMqttServer();
try
{
var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0;


s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;
s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;


var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();
var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();


await s.StartAsync(new MqttServerOptions());
await s.StartAsync(new MqttServerOptions());


var c1 = new MqttFactory().CreateMqttClient();
var c1 = new MqttFactory().CreateMqttClient();


await c1.ConnectAsync(clientOptions);
await c1.ConnectAsync(clientOptions);


await Task.Delay(100);
await Task.Delay(100);


await c1.DisconnectAsync();
await c1.DisconnectAsync();


await Task.Delay(100);
await Task.Delay(100);


await s.StopAsync();
await s.StopAsync();


await Task.Delay(100);
await Task.Delay(100);


Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
}
finally
{
await s.StopAsync();
}
} }


[TestMethod] [TestMethod]
@@ -410,11 +416,13 @@ namespace MQTTnet.Tests
.WithPayload("value" + j).WithRetainFlag().Build()).GetAwaiter().GetResult(); .WithPayload("value" + j).WithRetainFlag().Build()).GetAwaiter().GetResult();
} }


Thread.Sleep(100);

client.DisconnectAsync().GetAwaiter().GetResult(); client.DisconnectAsync().GetAwaiter().GetResult();
} }
}); });


await Task.Delay(100);
await Task.Delay(1000);


var retainedMessages = server.GetRetainedMessages(); var retainedMessages = server.GetRetainedMessages();


@@ -432,7 +440,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_RetainedMessagesFlow()
public async Task MqttServer_Retained_Messages_Flow()
{ {
var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build();
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
@@ -440,9 +448,9 @@ namespace MQTTnet.Tests
await s.StartAsync(new MqttServerOptions()); await s.StartAsync(new MqttServerOptions());
var c1 = await serverAdapter.ConnectTestClient("c1"); var c1 = await serverAdapter.ConnectTestClient("c1");
await c1.PublishAsync(retainedMessage); await c1.PublishAsync(retainedMessage);
Thread.Sleep(500);
await Task.Delay(500);
await c1.DisconnectAsync(); await c1.DisconnectAsync();
Thread.Sleep(500);
await Task.Delay(500);


var receivedMessages = 0; var receivedMessages = 0;
var c2 = await serverAdapter.ConnectTestClient("c2"); var c2 = await serverAdapter.ConnectTestClient("c2");
@@ -468,7 +476,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_NoRetainedMessage()
public async Task MqttServer_No_Retained_Message()
{ {
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -498,7 +506,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_RetainedMessage()
public async Task MqttServer_Retained_Message()
{ {
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -535,7 +543,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_ClearRetainedMessage()
public async Task MqttServer_Clear_Retained_Message()
{ {
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -567,7 +575,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_PersistRetainedMessage()
public async Task MqttServer_Persist_Retained_Message()
{ {
var storage = new TestStorage(); var storage = new TestStorage();
var serverAdapter = new TestMqttServerAdapter(); var serverAdapter = new TestMqttServerAdapter();
@@ -629,7 +637,7 @@ namespace MQTTnet.Tests


try try
{ {
var options = new MqttServerOptions { ApplicationMessageInterceptor = Interceptor };
var options = new MqttServerOptions { ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(c => Interceptor(c)) };


await s.StartAsync(options); await s.StartAsync(options);


@@ -692,7 +700,7 @@ namespace MQTTnet.Tests
} }


[TestMethod] [TestMethod]
public async Task MqttServer_ConnectionDenied()
public async Task MqttServer_Connection_Denied()
{ {
var server = new MqttFactory().CreateMqttServer(); var server = new MqttFactory().CreateMqttServer();
var client = new MqttFactory().CreateMqttClient(); var client = new MqttFactory().CreateMqttClient();
@@ -791,7 +799,6 @@ namespace MQTTnet.Tests
Assert.AreEqual("cdcd", flow); Assert.AreEqual("cdcd", flow);
} }



[TestMethod] [TestMethod]
public async Task MqttServer_StopAndRestart() public async Task MqttServer_StopAndRestart()
{ {
@@ -820,6 +827,66 @@ namespace MQTTnet.Tests
await server.StopAsync(); await server.StopAsync();
} }


[TestMethod]
public async Task MqttServer_Close_Idle_Connection()
{
var server = new MqttFactory().CreateMqttServer();

try
{
await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build());

var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync("localhost", 1883);

// Don't send anything. The server should close the connection.
await Task.Delay(TimeSpan.FromSeconds(5));

try
{
await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
Assert.Fail("Receive should throw an exception.");
}
catch (SocketException)
{
}
}
finally
{
await server.StopAsync();
}
}

[TestMethod]
public async Task MqttServer_Send_Garbage()
{
var server = new MqttFactory().CreateMqttServer();

try
{
await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build());

var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync("localhost", 1883);
await client.SendAsync(Encoding.UTF8.GetBytes("Garbage"), SocketFlags.None);

await Task.Delay(TimeSpan.FromSeconds(5));

try
{
await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
Assert.Fail("Receive should throw an exception.");
}
catch (SocketException)
{
}
}
finally
{
await server.StopAsync();
}
}

private class TestStorage : IMqttServerStorage private class TestStorage : IMqttServerStorage
{ {
public IList<MqttApplicationMessage> Messages = new List<MqttApplicationMessage>(); public IList<MqttApplicationMessage> Messages = new List<MqttApplicationMessage>();


+ 5
- 5
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs View File

@@ -16,7 +16,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());


sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();


var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce);
Assert.IsTrue(result.IsSubscribed); Assert.IsTrue(result.IsSubscribed);
@@ -31,7 +31,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });


sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();


var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed); Assert.IsTrue(result.IsSubscribed);
@@ -47,7 +47,7 @@ namespace MQTTnet.Tests
sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce });


sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();


var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed); Assert.IsTrue(result.IsSubscribed);
@@ -62,7 +62,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());


sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();


Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);
} }
@@ -75,7 +75,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());


sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();


Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);




+ 7
- 6
Tests/MQTTnet.TestApp.NetCore/ServerTest.cs View File

@@ -23,7 +23,7 @@ namespace MQTTnet.TestApp.NetCore
{ {
var options = new MqttServerOptions var options = new MqttServerOptions
{ {
ConnectionValidator = p =>
ConnectionValidator = new MqttServerConnectionValidatorDelegate(p =>
{ {
if (p.ClientId == "SpecialClient") if (p.ClientId == "SpecialClient")
{ {
@@ -32,11 +32,11 @@ namespace MQTTnet.TestApp.NetCore
p.ReturnCode = MqttConnectReturnCode.ConnectionRefusedBadUsernameOrPassword; p.ReturnCode = MqttConnectReturnCode.ConnectionRefusedBadUsernameOrPassword;
} }
} }
},
}),


Storage = new RetainedMessageHandler(), Storage = new RetainedMessageHandler(),


ApplicationMessageInterceptor = context =>
ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(context =>
{ {
if (MqttTopicFilterComparer.IsMatch(context.ApplicationMessage.Topic, "/myTopic/WithTimestamp/#")) if (MqttTopicFilterComparer.IsMatch(context.ApplicationMessage.Topic, "/myTopic/WithTimestamp/#"))
{ {
@@ -50,8 +50,9 @@ namespace MQTTnet.TestApp.NetCore
context.AcceptPublish = false; context.AcceptPublish = false;
context.CloseConnection = true; context.CloseConnection = true;
} }
},
SubscriptionInterceptor = context =>
}),

SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(context =>
{ {
if (context.TopicFilter.Topic.StartsWith("admin/foo/bar") && context.ClientId != "theAdmin") if (context.TopicFilter.Topic.StartsWith("admin/foo/bar") && context.ClientId != "theAdmin")
{ {
@@ -63,7 +64,7 @@ namespace MQTTnet.TestApp.NetCore
context.AcceptSubscription = false; context.AcceptSubscription = false;
context.CloseConnection = true; context.CloseConnection = true;
} }
}
})
}; };


// Extend the timestamp for all messages from clients. // Extend the timestamp for all messages from clients.


+ 7
- 7
Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs View File

@@ -450,7 +450,7 @@ namespace MQTTnet.TestApp.UniversalWindows
return; return;
} }


var sessions = _mqttServer.GetClientSessionsStatus();
var sessions = _mqttServer.GetClientSessionsStatusAsync().GetAwaiter().GetResult();
_sessions.Clear(); _sessions.Clear();


foreach (var session in sessions) foreach (var session in sessions)
@@ -568,7 +568,7 @@ namespace MQTTnet.TestApp.UniversalWindows
{ {
var options = new MqttServerOptions(); var options = new MqttServerOptions();


options.ConnectionValidator = c =>
options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(c =>
{ {
if (c.ClientId.Length < 10) if (c.ClientId.Length < 10)
{ {
@@ -589,7 +589,7 @@ namespace MQTTnet.TestApp.UniversalWindows
} }


c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
};
});


var factory = new MqttFactory(); var factory = new MqttFactory();
var mqttServer = factory.CreateMqttServer(); var mqttServer = factory.CreateMqttServer();
@@ -633,7 +633,7 @@ namespace MQTTnet.TestApp.UniversalWindows
{ {
}; };


options.ConnectionValidator = c =>
options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(c =>
{ {
if (c.ClientId != "Highlander") if (c.ClientId != "Highlander")
{ {
@@ -642,7 +642,7 @@ namespace MQTTnet.TestApp.UniversalWindows
} }


c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
};
});


var mqttServer = new MqttFactory().CreateMqttServer(); var mqttServer = new MqttFactory().CreateMqttServer();
await mqttServer.StartAsync(optionsBuilder.Build()); await mqttServer.StartAsync(optionsBuilder.Build());
@@ -652,7 +652,7 @@ namespace MQTTnet.TestApp.UniversalWindows
// Setup client validator. // Setup client validator.
var options = new MqttServerOptions var options = new MqttServerOptions
{ {
ConnectionValidator = c =>
ConnectionValidator = new MqttServerConnectionValidatorDelegate(c =>
{ {
if (c.ClientId.Length < 10) if (c.ClientId.Length < 10)
{ {
@@ -673,7 +673,7 @@ namespace MQTTnet.TestApp.UniversalWindows
} }


c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
}
})
}; };
} }




Loading…
Cancel
Save