Procházet zdrojové kódy

Refactor interceptors etc. to support async.

release/3.x.x
Christian Kratky před 6 roky
rodič
revize
c0507fcc55
33 změnil soubory, kde provedl 542 přidání a 316 odebrání
  1. +2
    -2
      Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs
  2. +1
    -1
      Source/MQTTnet/Formatter/MqttPacketReader.cs
  3. +16
    -10
      Source/MQTTnet/Formatter/MqttPacketWriter.cs
  4. +2
    -2
      Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs
  5. +2
    -2
      Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs
  6. +24
    -10
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  7. +16
    -7
      Source/MQTTnet/Internal/AsyncLock.cs
  8. +2
    -1
      Source/MQTTnet/Server/IMqttClientSession.cs
  9. +0
    -3
      Source/MQTTnet/Server/IMqttServer.cs
  10. +9
    -0
      Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs
  11. +9
    -0
      Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs
  12. +9
    -0
      Source/MQTTnet/Server/IMqttServerConnectionValidator.cs
  13. +4
    -4
      Source/MQTTnet/Server/IMqttServerOptions.cs
  14. +9
    -0
      Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs
  15. +1
    -1
      Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs
  16. +2
    -2
      Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs
  17. +107
    -110
      Source/MQTTnet/Server/MqttClientSession.cs
  18. +4
    -7
      Source/MQTTnet/Server/MqttClientSessionStatus.cs
  19. +46
    -50
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  20. +11
    -6
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  21. +21
    -28
      Source/MQTTnet/Server/MqttRetainedMessagesManager.cs
  22. +3
    -8
      Source/MQTTnet/Server/MqttServer.cs
  23. +31
    -0
      Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs
  24. +31
    -0
      Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs
  25. +5
    -5
      Source/MQTTnet/Server/MqttServerOptions.cs
  26. +21
    -3
      Source/MQTTnet/Server/MqttServerOptionsBuilder.cs
  27. +31
    -0
      Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs
  28. +1
    -1
      Tests/MQTTnet.Core.Tests/AsyncLockTests.cs
  29. +2
    -1
      Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs
  30. +101
    -34
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  31. +5
    -5
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs
  32. +7
    -6
      Tests/MQTTnet.TestApp.NetCore/ServerTest.cs
  33. +7
    -7
      Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs

+ 2
- 2
Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs Zobrazit soubor

@@ -30,7 +30,7 @@ namespace MQTTnet.Extensions.ManagedClient
{
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

using (await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
using (await _messagesLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
_messages.Add(applicationMessage);
await SaveAsync().ConfigureAwait(false);
@@ -41,7 +41,7 @@ namespace MQTTnet.Extensions.ManagedClient
{
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

using (await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
using (await _messagesLock.WaitAsync(CancellationToken.None).ConfigureAwait(false))
{
var index = _messages.IndexOf(applicationMessage);
if (index == -1)


+ 1
- 1
Source/MQTTnet/Formatter/MqttPacketReader.cs Zobrazit soubor

@@ -63,7 +63,7 @@ namespace MQTTnet.Formatter
{
var offset = 0;
var multiplier = 128;
var value = (initialEncodedByte & 127);
var value = initialEncodedByte & 127;
int encodedByte = initialEncodedByte;

while ((encodedByte & 128) != 0)


+ 16
- 10
Source/MQTTnet/Formatter/MqttPacketWriter.cs Zobrazit soubor

@@ -15,6 +15,7 @@ namespace MQTTnet.Formatter
{
public static int MaxBufferSize = 4096;

// TODO: Consider using the ArrayPool here together with FreeBuffer.
private byte[] _buffer = new byte[128];

private int _offset;
@@ -28,13 +29,18 @@ namespace MQTTnet.Formatter
return (byte)fixedHeader;
}

public static ArraySegment<byte> EncodeVariableByteInteger(uint value)
public static ArraySegment<byte> EncodeVariableLengthInteger(uint value)
{
if (value <= 0)
if (value == 0)
{
return new ArraySegment<byte>(new byte[1], 0, 1);
}

if (value <= 127)
{
return new ArraySegment<byte>(new[] { (byte)value }, 0, 1);
}

var buffer = new byte[4];
var bufferOffset = 0;

@@ -57,7 +63,7 @@ namespace MQTTnet.Formatter

public void WriteVariableLengthInteger(uint value)
{
Write(EncodeVariableByteInteger(value));
Write(EncodeVariableLengthInteger(value));
}

public void WriteWithLengthPrefix(string value)
@@ -80,7 +86,7 @@ namespace MQTTnet.Formatter
EnsureAdditionalCapacity(1);

_buffer[_offset] = @byte;
IncreasePostition(1);
IncreasePosition(1);
}

public void Write(ushort value)
@@ -88,9 +94,9 @@ namespace MQTTnet.Formatter
EnsureAdditionalCapacity(2);

_buffer[_offset] = (byte)(value >> 8);
IncreasePostition(1);
IncreasePosition(1);
_buffer[_offset] = (byte)value;
IncreasePostition(1);
IncreasePosition(1);
}

public void Write(byte[] buffer, int offset, int count)
@@ -100,7 +106,7 @@ namespace MQTTnet.Formatter
EnsureAdditionalCapacity(count);

Array.Copy(buffer, offset, _buffer, _offset, count);
IncreasePostition(count);
IncreasePosition(count);
}

public void Write(ArraySegment<byte> buffer)
@@ -122,9 +128,9 @@ namespace MQTTnet.Formatter
Write(propertyWriter._buffer, 0, propertyWriter.Length);
}

public void Reset()
public void Reset(int length)
{
Length = 5;
Length = length;
}

public void Seek(int position)
@@ -185,7 +191,7 @@ namespace MQTTnet.Formatter
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private void IncreasePostition(int length)
private void IncreasePosition(int length)
{
_offset += length;



+ 2
- 2
Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs Zobrazit soubor

@@ -21,13 +21,13 @@ namespace MQTTnet.Formatter.V3
if (packet == null) throw new ArgumentNullException(nameof(packet));

// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes)
_packetWriter.Reset();
_packetWriter.Reset(5);
_packetWriter.Seek(5);

var fixedHeader = EncodePacket(packet, _packetWriter);
var remainingLength = (uint)(_packetWriter.Length - 5);

var remainingLengthBuffer = MqttPacketWriter.EncodeVariableByteInteger(remainingLength);
var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength);

var headerSize = FixedHeaderSize + remainingLengthBuffer.Count;
var headerOffset = 5 - headerSize;


+ 2
- 2
Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs Zobrazit soubor

@@ -15,13 +15,13 @@ namespace MQTTnet.Formatter.V5
if (packet == null) throw new ArgumentNullException(nameof(packet));

// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes)
_packetWriter.Reset();
_packetWriter.Reset(5);
_packetWriter.Seek(5);

var fixedHeader = EncodePacket(packet, _packetWriter);
var remainingLength = (uint)(_packetWriter.Length - 5);

var remainingLengthBuffer = MqttPacketWriter.EncodeVariableByteInteger(remainingLength);
var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength);

var headerSize = 1 + remainingLengthBuffer.Count;
var headerOffset = 5 - headerSize;


+ 24
- 10
Source/MQTTnet/Implementations/MqttTcpChannel.cs Zobrazit soubor

@@ -52,11 +52,15 @@ namespace MQTTnet.Implementations
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
}

// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
#if NET452 || NET461
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false);
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false);
#else
await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false);
await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false);
#endif
}

SslStream sslStream = null;
if (_options.TlsOptions.UseTls)
@@ -74,14 +78,23 @@ namespace MQTTnet.Implementations
return Task.FromResult(0);
}

public Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _stream.ReadAsync(buffer, offset, count, cancellationToken);
// Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
}

public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _stream.WriteAsync(buffer, offset, count, cancellationToken);
// Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken);
}
}

public void Dispose()
@@ -89,10 +102,11 @@ namespace MQTTnet.Implementations
Cleanup(ref _stream, s => s.Dispose());
Cleanup(ref _socket, s =>
{
if (s.Connected)
{
s.Shutdown(SocketShutdown.Both);
}
//if (s.Connected)
//{
// s.Shutdown(SocketShutdown.Both);
//}

s.Dispose();
});
}


+ 16
- 7
Source/MQTTnet/Internal/AsyncLock.cs Zobrazit soubor

@@ -15,14 +15,23 @@ namespace MQTTnet.Internal
_releaser = Task.FromResult((IDisposable)new Releaser(this));
}

public Task<IDisposable> LockAsync(CancellationToken cancellationToken)
public Task<IDisposable> WaitAsync()
{
var wait = _semaphore.WaitAsync(cancellationToken);
return wait.IsCompleted ?
_releaser :
wait.ContinueWith((_, state) => (IDisposable)state,
_releaser.Result, cancellationToken,
TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
return WaitAsync(CancellationToken.None);
}

public Task<IDisposable> WaitAsync(CancellationToken cancellationToken)
{
var task = _semaphore.WaitAsync(cancellationToken);
if (task.IsCompleted)
{
return _releaser;
}

return task.ContinueWith(
(_, state) => (IDisposable)state,
_releaser.Result,
cancellationToken, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
}

public void Dispose()


+ 2
- 1
Source/MQTTnet/Server/IMqttClientSession.cs Zobrazit soubor

@@ -1,4 +1,5 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
@@ -6,6 +7,6 @@ namespace MQTTnet.Server
{
string ClientId { get; }

void Stop(MqttClientDisconnectType disconnectType);
Task StopAsync(MqttClientDisconnectType disconnectType);
}
}

+ 0
- 3
Source/MQTTnet/Server/IMqttServer.cs Zobrazit soubor

@@ -16,11 +16,8 @@ namespace MQTTnet.Server
IMqttServerOptions Options { get; }

[Obsolete("This method is no longer async. Use the not async method.")]
Task<IList<IMqttClientSessionStatus>> GetClientSessionsStatusAsync();

IList<IMqttClientSessionStatus> GetClientSessionsStatus();

IList<MqttApplicationMessage> GetRetainedMessages();
Task ClearRetainedMessagesAsync();



+ 9
- 0
Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs Zobrazit soubor

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerApplicationMessageInterceptor
{
Task InterceptApplicationMessagePublishAsync(MqttApplicationMessageInterceptorContext context);
}
}

+ 9
- 0
Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs Zobrazit soubor

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerClientMessageQueueInterceptor
{
Task InterceptClientMessageQueueEnqueueAsync(MqttClientMessageQueueInterceptorContext context);
}
}

+ 9
- 0
Source/MQTTnet/Server/IMqttServerConnectionValidator.cs Zobrazit soubor

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerConnectionValidator
{
Task ValidateConnection(MqttConnectionValidatorContext context);
}
}

+ 4
- 4
Source/MQTTnet/Server/IMqttServerOptions.cs Zobrazit soubor

@@ -11,10 +11,10 @@ namespace MQTTnet.Server

TimeSpan DefaultCommunicationTimeout { get; }

Action<MqttConnectionValidatorContext> ConnectionValidator { get; }
Action<MqttSubscriptionInterceptorContext> SubscriptionInterceptor { get; }
Action<MqttApplicationMessageInterceptorContext> ApplicationMessageInterceptor { get; }
Action<MqttClientMessageQueueInterceptorContext> ClientMessageQueueInterceptor { get; }
IMqttServerConnectionValidator ConnectionValidator { get; }
IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; }
IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; }
IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; }

MqttServerTcpEndpointOptions DefaultEndpointOptions { get; }
MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; }


+ 9
- 0
Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs Zobrazit soubor

@@ -0,0 +1,9 @@
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public interface IMqttServerSubscriptionInterceptor
{
Task InterceptSubscriptionAsync(MqttSubscriptionInterceptorContext context);
}
}

+ 1
- 1
Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs Zobrazit soubor

@@ -81,7 +81,7 @@ namespace MQTTnet.Server
if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds >= keepAlivePeriod * 1.5D)
{
_logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId);
_clientSession.Stop(MqttClientDisconnectType.NotClean);
await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);

return;
}


+ 2
- 2
Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs Zobrazit soubor

@@ -133,7 +133,7 @@ namespace MQTTnet.Server
return;
}

adapter.SendPacketAsync(packet, cancellationToken).GetAwaiter().GetResult();
await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false);

_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId);
}
@@ -167,7 +167,7 @@ namespace MQTTnet.Server

if (!cancellationToken.IsCancellationRequested)
{
_clientSession.Stop(MqttClientDisconnectType.NotClean);
await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
}
}
}


+ 107
- 110
Source/MQTTnet/Server/MqttClientSession.cs Zobrazit soubor

@@ -79,22 +79,20 @@ namespace MQTTnet.Server
return _workerTask;
}

public void Stop(MqttClientDisconnectType type)
public Task StopAsync(MqttClientDisconnectType type)
{
Stop(type, false);
return StopAsync(type, false);
}

public Task SubscribeAsync(IList<TopicFilter> topicFilters)
public async Task SubscribeAsync(IList<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

var packet = new MqttSubscribePacket();
packet.TopicFilters.AddRange(topicFilters);

_subscriptionsManager.Subscribe(packet);

EnqueueSubscribedRetainedMessages(topicFilters);
return Task.FromResult(0);
await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false);
await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false);
}

public Task UnsubscribeAsync(IList<string> topicFilters)
@@ -122,7 +120,7 @@ namespace MQTTnet.Server
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
}
private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
{
if (adapter == null) throw new ArgumentNullException(nameof(adapter));
@@ -131,7 +129,7 @@ namespace MQTTnet.Server
{
if (_cancellationTokenSource != null)
{
Stop(MqttClientDisconnectType.Clean, true);
await StopAsync(MqttClientDisconnectType.Clean, true).ConfigureAwait(false);
}

adapter.ReadingPacketStarted += OnAdapterReadingPacketStarted;
@@ -139,15 +137,15 @@ namespace MQTTnet.Server

_cancellationTokenSource = new CancellationTokenSource();

//workaround for https://github.com/dotnet/corefx/issues/24430
#pragma warning disable 4014
_cleanupHandle = _cancellationTokenSource.Token.Register(async () =>
{
await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
TryDisposeAdapter(adapter);
});
#pragma warning restore 4014
//end workaround
// //workaround for https://github.com/dotnet/corefx/issues/24430
//#pragma warning disable 4014
// _cleanupHandle = _cancellationTokenSource.Token.Register(async () =>
// {
// await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
// TryDisposeAdapter(adapter);
// });
//#pragma warning restore 4014
// //end workaround

_wasCleanDisconnect = false;
_willMessage = connectPacket.WillMessage;
@@ -165,7 +163,7 @@ namespace MQTTnet.Server
if (packet != null)
{
_keepAliveMonitor.PacketReceived(packet);
ProcessReceivedPacket(adapter, packet, _cancellationTokenSource.Token);
await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false);
}
}
}
@@ -190,51 +188,55 @@ namespace MQTTnet.Server
_logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId);
}

Stop(MqttClientDisconnectType.NotClean, true);
await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
}
finally
{
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;

_adapterEndpoint = null;
_adapterProtocolVersion = null;

// Uncomment as soon as the workaround above is no longer needed.
// Also called in outer scope!
//await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
//TryDisposeAdapter(adapter);

_cleanupHandle?.Dispose();
_cleanupHandle = null;
_cancellationTokenSource?.Cancel(false);
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
}
}

private void TryDisposeAdapter(IMqttChannelAdapter adapter)
{
if (adapter == null)
{
return;
}
try
{
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
adapter.Dispose();
}
catch (Exception exception)
{
_logger.Error(exception, exception.Message);
}
finally
{
adapter.Dispose();
}
}
private void Stop(MqttClientDisconnectType type, bool isInsideSession)
////private void TryDisposeAdapter(IMqttChannelAdapter adapter)
////{
//// if (adapter == null)
//// {
//// return;
//// }
//// try
//// {
//// adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
//// adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;
//// adapter.Dispose();
//// }
//// catch (Exception exception)
//// {
//// _logger.Error(exception, exception.Message);
//// }
//// finally
//// {
//// adapter.Dispose();
//// }
////}
private async Task StopAsync(MqttClientDisconnectType type, bool isInsideSession)
{
try
{
@@ -257,7 +259,10 @@ namespace MQTTnet.Server

if (!isInsideSession)
{
_workerTask?.GetAwaiter().GetResult();
if (_workerTask != null)
{
await _workerTask.ConfigureAwait(false);
}
}
}
finally
@@ -267,7 +272,7 @@ namespace MQTTnet.Server
}
}

public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage)
public async Task EnqueueApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage)
{
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

@@ -278,10 +283,10 @@ namespace MQTTnet.Server
}

var publishPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage);
// Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9].
publishPacket.Retain = isRetainedApplicationMessage;
if (publishPacket.QualityOfServiceLevel > 0)
{
publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier();
@@ -294,7 +299,10 @@ namespace MQTTnet.Server
ClientId,
applicationMessage);

_options.ClientMessageQueueInterceptor?.Invoke(context);
if (_options.ClientMessageQueueInterceptor != null)
{
await _options.ClientMessageQueueInterceptor.InterceptClientMessageQueueEnqueueAsync(context).ConfigureAwait(false);
}

if (!context.AcceptEnqueue || context.ApplicationMessage == null)
{
@@ -309,35 +317,33 @@ namespace MQTTnet.Server
_pendingPacketsQueue.Enqueue(publishPacket);
}

private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter)
{
if (adapter == null)
{
return;
}
try
{
await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while disconnecting channel adapter.");
}
}
private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
//private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter)
//{
// if (adapter == null)
// {
// return;
// }
// try
// {
// await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
// }
// catch (Exception exception)
// {
// _logger.Error(exception, "Error while disconnecting channel adapter.");
// }
//}
private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
{
if (packet is MqttPublishPacket publishPacket)
{
HandleIncomingPublishPacket(adapter, publishPacket, cancellationToken);
return;
return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken);
}

if (packet is MqttPingReqPacket)
{
adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken).GetAwaiter().GetResult();
return;
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
}

if (packet is MqttPubRelPacket pubRelPacket)
@@ -348,8 +354,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubCompReasonCode.Success
};

adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult();
return;
return adapter.SendPacketAsync(responsePacket, cancellationToken);
}

if (packet is MqttPubRecPacket pubRecPacket)
@@ -360,91 +365,83 @@ namespace MQTTnet.Server
ReasonCode = MqttPubRelReasonCode.Success
};

adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult();
return;
return adapter.SendPacketAsync(responsePacket, cancellationToken);
}

if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
{
return;
return Task.FromResult(0);
}

if (packet is MqttSubscribePacket subscribePacket)
{
HandleIncomingSubscribePacket(adapter, subscribePacket, cancellationToken);
return;
return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken);
}

if (packet is MqttUnsubscribePacket unsubscribePacket)
{
HandleIncomingUnsubscribePacket(adapter, unsubscribePacket, cancellationToken);
return;
return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken);
}

if (packet is MqttDisconnectPacket)
{
Stop(MqttClientDisconnectType.Clean, true);
return;
return StopAsync(MqttClientDisconnectType.Clean, true);
}

if (packet is MqttConnectPacket)
{
Stop(MqttClientDisconnectType.NotClean, true);
return;
return StopAsync(MqttClientDisconnectType.NotClean, true);
}

_logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet);

Stop(MqttClientDisconnectType.NotClean, true);
return StopAsync(MqttClientDisconnectType.NotClean, true);
}

private void EnqueueSubscribedRetainedMessages(ICollection<TopicFilter> topicFilters)
private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
{
var retainedMessages = _retainedMessagesManager.GetSubscribedMessages(topicFilters);
var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false);
foreach (var applicationMessage in retainedMessages)
{
EnqueueApplicationMessage(null, applicationMessage, true);
await EnqueueApplicationMessageAsync(null, applicationMessage, true).ConfigureAwait(false);
}
}

private void HandleIncomingSubscribePacket(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
{
var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket);
adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).GetAwaiter().GetResult();
var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false);
await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false);

if (subscribeResult.CloseConnection)
{
Stop(MqttClientDisconnectType.NotClean, true);
return;
await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false);
}

EnqueueSubscribedRetainedMessages(subscribePacket.TopicFilters);
await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false);
}

private void HandleIncomingUnsubscribePacket(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
{
var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket);
adapter.SendPacketAsync(unsubscribeResult, cancellationToken).GetAwaiter().GetResult();
return adapter.SendPacketAsync(unsubscribeResult, cancellationToken);
}

private void HandleIncomingPublishPacket(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{
switch (publishPacket.QualityOfServiceLevel)
{
case MqttQualityOfServiceLevel.AtMostOnce:
{
HandleIncomingPublishPacketWithQoS0(publishPacket);
break;
return Task.FromResult(0);
}
case MqttQualityOfServiceLevel.AtLeastOnce:
{
HandleIncomingPublishPacketWithQoS1(adapter, publishPacket, cancellationToken);
break;
return HandleIncomingPublishPacketWithQoS1Async(adapter, publishPacket, cancellationToken);
}
case MqttQualityOfServiceLevel.ExactlyOnce:
{
HandleIncomingPublishPacketWithQoS2(adapter, publishPacket, cancellationToken);
break;
return HandleIncomingPublishPacketWithQoS2Async(adapter, publishPacket, cancellationToken);
}
default:
{
@@ -456,17 +453,17 @@ namespace MQTTnet.Server
private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket)
{
_sessionsManager.EnqueueApplicationMessage(
this,
this,
_channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket));
}

private void HandleIncomingPublishPacketWithQoS1(
private Task HandleIncomingPublishPacketWithQoS1Async(
IMqttChannelAdapter adapter,
MqttPublishPacket publishPacket,
CancellationToken cancellationToken)
{
_sessionsManager.EnqueueApplicationMessage(
this,
this,
_channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket));

var response = new MqttPubAckPacket
@@ -475,10 +472,10 @@ namespace MQTTnet.Server
ReasonCode = MqttPubAckReasonCode.Success
};

adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult();
return adapter.SendPacketAsync(response, cancellationToken);
}

private void HandleIncomingPublishPacketWithQoS2(
private Task HandleIncomingPublishPacketWithQoS2Async(
IMqttChannelAdapter adapter,
MqttPublishPacket publishPacket,
CancellationToken cancellationToken)
@@ -492,7 +489,7 @@ namespace MQTTnet.Server
ReasonCode = MqttPubRecReasonCode.Success
};

adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult();
return adapter.SendPacketAsync(response, cancellationToken);
}

private void OnAdapterReadingPacketCompleted(object sender, EventArgs e)


+ 4
- 7
Source/MQTTnet/Server/MqttClientSessionStatus.cs Zobrazit soubor

@@ -25,22 +25,19 @@ namespace MQTTnet.Server

public Task DisconnectAsync()
{
_session.Stop(MqttClientDisconnectType.NotClean);
return Task.FromResult(0);
return _session.StopAsync(MqttClientDisconnectType.NotClean);
}

public Task DeleteSessionAsync()
public async Task DeleteSessionAsync()
{
try
{
_session.Stop(MqttClientDisconnectType.NotClean);
await _session.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
}
finally
{
_sessionsManager.DeleteSession(ClientId);
await _sessionsManager.DeleteSessionAsync(ClientId).ConfigureAwait(false);
}

return Task.FromResult(0);
}

public Task ClearPendingApplicationMessagesAsync()


+ 46
- 50
Source/MQTTnet/Server/MqttClientSessionsManager.cs Zobrazit soubor

@@ -1,12 +1,11 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;

@@ -16,9 +15,7 @@ namespace MQTTnet.Server
{
private readonly BlockingCollection<MqttEnqueuedApplicationMessage> _messageQueue = new BlockingCollection<MqttEnqueuedApplicationMessage>();

/// <summary>
/// manual locking dictionaries is faster than using concurrent dictionary
/// </summary>
private readonly AsyncLock _sessionsLock = new AsyncLock();
private readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>();

private readonly CancellationToken _cancellationToken;
@@ -47,16 +44,16 @@ namespace MQTTnet.Server

public void Start()
{
Task.Factory.StartNew(() => TryProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
Task.Factory.StartNew(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
}

public void Stop()
public async Task StopAsync()
{
lock (_sessions)
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{
foreach (var session in _sessions)
{
session.Value.Stop(MqttClientDisconnectType.NotClean);
await session.Value.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
}

_sessions.Clear();
@@ -68,18 +65,21 @@ namespace MQTTnet.Server
return Task.Run(() => RunSessionAsync(clientAdapter, _cancellationToken), _cancellationToken);
}

public IList<IMqttClientSessionStatus> GetClientStatus()
public async Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
{
var result = new List<IMqttClientSessionStatus>();

foreach (var session in GetSessions())
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{
var status = new MqttClientSessionStatus(this, session);
session.FillStatus(status);
foreach (var session in _sessions.Values)
{
var status = new MqttClientSessionStatus(this, session);
session.FillStatus(status);

result.Add(status);
result.Add(status);
}
}
return result;
}

@@ -122,9 +122,9 @@ namespace MQTTnet.Server
}
}

public void DeleteSession(string clientId)
public async Task DeleteSessionAsync(string clientId)
{
lock (_sessions)
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{
_sessions.Remove(clientId);
}
@@ -137,13 +137,13 @@ namespace MQTTnet.Server
_messageQueue?.Dispose();
}

private void TryProcessQueuedApplicationMessages(CancellationToken cancellationToken)
private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
TryProcessNextQueuedApplicationMessage(cancellationToken);
await TryProcessNextQueuedApplicationMessageAsync(cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
@@ -155,7 +155,7 @@ namespace MQTTnet.Server
}
}

private void TryProcessNextQueuedApplicationMessage(CancellationToken cancellationToken)
private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken)
{
try
{
@@ -164,12 +164,12 @@ namespace MQTTnet.Server
var sender = enqueuedApplicationMessage.Sender;
var applicationMessage = enqueuedApplicationMessage.ApplicationMessage;

var interceptorContext = InterceptApplicationMessage(sender, applicationMessage);
var interceptorContext = await InterceptApplicationMessageAsync(sender, applicationMessage).ConfigureAwait(false);
if (interceptorContext != null)
{
if (interceptorContext.CloseConnection)
{
enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean);
await enqueuedApplicationMessage.Sender.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false);
}

if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish)
@@ -184,12 +184,18 @@ namespace MQTTnet.Server

if (applicationMessage.Retain)
{
_retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult();
await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
}

foreach (var clientSession in GetSessions())
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{
clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, enqueuedApplicationMessage.ApplicationMessage, false);
foreach (var clientSession in _sessions.Values)
{
await clientSession.EnqueueApplicationMessageAsync(
enqueuedApplicationMessage.Sender,
enqueuedApplicationMessage.ApplicationMessage,
false).ConfigureAwait(false);
}
}
}
catch (OperationCanceledException)
@@ -201,35 +207,22 @@ namespace MQTTnet.Server
}
}

private List<MqttClientSession> GetSessions()
{
lock (_sessions)
{
return _sessions.Values.ToList();
}
}

private async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
{
var clientId = string.Empty;
try
{
// TODO: Catch cancel exception here if the first packet was not received and log properly.
var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
if (firstPacket == null)
{
return;
}

if (!(firstPacket is MqttConnectPacket connectPacket))
{
throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1].");
_logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", clientAdapter.Endpoint);
return;
}

clientId = connectPacket.ClientId;

var connectReturnCode = ValidateConnection(connectPacket, clientAdapter);
var connectReturnCode = await ValidateConnectionAsync(connectPacket, clientAdapter).ConfigureAwait(false);
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{
await clientAdapter.SendPacketAsync(
@@ -242,7 +235,7 @@ namespace MQTTnet.Server
return;
}

var result = PrepareClientSession(connectPacket);
var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);

await clientAdapter.SendPacketAsync(
new MqttConnAckPacket
@@ -267,14 +260,17 @@ namespace MQTTnet.Server
}
finally
{
await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
clientAdapter.Dispose();

if (!_options.EnablePersistentSessions)
{
DeleteSession(clientId);
await DeleteSessionAsync(clientId).ConfigureAwait(false);
}
}
}

private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter)
private async Task<MqttConnectReturnCode> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter)
{
if (_options.ConnectionValidator == null)
{
@@ -288,13 +284,13 @@ namespace MQTTnet.Server
connectPacket.WillMessage,
clientAdapter.Endpoint);

_options.ConnectionValidator(context);
await _options.ConnectionValidator.ValidateConnection(context).ConfigureAwait(false);
return context.ReturnCode;
}

private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket)
private async Task<PrepareClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
{
lock (_sessions)
using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false))
{
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
if (isSessionPresent)
@@ -303,7 +299,7 @@ namespace MQTTnet.Server
{
_sessions.Remove(connectPacket.ClientId);

clientSession.Stop(MqttClientDisconnectType.Clean);
await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false);
clientSession.Dispose();
clientSession = null;

@@ -330,7 +326,7 @@ namespace MQTTnet.Server
}
}

private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage)
private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientSession sender, MqttApplicationMessage applicationMessage)
{
var interceptor = _options.ApplicationMessageInterceptor;
if (interceptor == null)
@@ -339,7 +335,7 @@ namespace MQTTnet.Server
}

var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage);
interceptor(interceptorContext);
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
return interceptorContext;
}
}

+ 11
- 6
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs Zobrazit soubor

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MQTTnet.Packets;
using MQTTnet.Protocol;

@@ -20,7 +21,7 @@ namespace MQTTnet.Server
_eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher));
}

public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket)
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket)
{
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));

@@ -36,7 +37,7 @@ namespace MQTTnet.Server

foreach (var topicFilter in subscribePacket.TopicFilters)
{
var interceptorContext = InterceptSubscribe(topicFilter);
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
if (!interceptorContext.AcceptSubscription)
{
result.ResponsePacket.ReturnCodes.Add(MqttSubscribeReturnCode.Failure);
@@ -146,11 +147,15 @@ namespace MQTTnet.Server
}
}

private MqttSubscriptionInterceptorContext InterceptSubscribe(TopicFilter topicFilter)
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter)
{
var interceptorContext = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
_options.SubscriptionInterceptor?.Invoke(interceptorContext);
return interceptorContext;
var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
if (_options.SubscriptionInterceptor != null)
{
await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false);
}

return context;
}

private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels)


+ 21
- 28
Source/MQTTnet/Server/MqttRetainedMessagesManager.cs Zobrazit soubor

@@ -3,11 +3,14 @@ using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MQTTnet.Diagnostics;
using MQTTnet.Internal;

namespace MQTTnet.Server
{
public class MqttRetainedMessagesManager
{
private readonly byte[] _emptyArray = new byte[0];
private readonly AsyncLock _messagesLock = new AsyncLock();
private readonly Dictionary<string, MqttApplicationMessage> _messages = new Dictionary<string, MqttApplicationMessage>();

private readonly IMqttNetChildLogger _logger;
@@ -31,7 +34,7 @@ namespace MQTTnet.Server
{
var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false);

lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{
_messages.Clear();
foreach (var retainedMessage in retainedMessages)
@@ -52,8 +55,7 @@ namespace MQTTnet.Server

try
{
List<MqttApplicationMessage> messagesForSave = null;
lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{
var saveIsRequired = false;

@@ -71,7 +73,7 @@ namespace MQTTnet.Server
}
else
{
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0]))
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? _emptyArray))
{
_messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true;
@@ -83,20 +85,13 @@ namespace MQTTnet.Server

if (saveIsRequired)
{
messagesForSave = new List<MqttApplicationMessage>(_messages.Values);
if (_options.Storage != null)
{
var messagesForSave = new List<MqttApplicationMessage>(_messages.Values);
await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false);
}
}
}

if (messagesForSave == null)
{
_logger.Verbose("Skipped saving retained messages because no changes were detected.");
return;
}

if (_options.Storage != null)
{
await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false);
}
}
catch (Exception exception)
{
@@ -104,11 +99,11 @@ namespace MQTTnet.Server
}
}

public IList<MqttApplicationMessage> GetSubscribedMessages(ICollection<TopicFilter> topicFilters)
public async Task<List<MqttApplicationMessage>> GetSubscribedMessagesAsync(ICollection<TopicFilter> topicFilters)
{
var retainedMessages = new List<MqttApplicationMessage>();

lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{
foreach (var retainedMessage in _messages.Values)
{
@@ -128,27 +123,25 @@ namespace MQTTnet.Server
return retainedMessages;
}

public IList<MqttApplicationMessage> GetMessages()
public async Task<List<MqttApplicationMessage>> GetMessagesAsync()
{
lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{
return _messages.Values.ToList();
}
}

public Task ClearMessagesAsync()
public async Task ClearMessagesAsync()
{
lock (_messages)
using (await _messagesLock.WaitAsync().ConfigureAwait(false))
{
_messages.Clear();
}

if (_options.Storage != null)
{
return _options.Storage.SaveRetainedMessagesAsync(new List<MqttApplicationMessage>());
if (_options.Storage != null)
{
await _options.Storage.SaveRetainedMessagesAsync(new List<MqttApplicationMessage>()).ConfigureAwait(false);
}
}

return Task.FromResult((object)null);
}
}
}

+ 3
- 8
Source/MQTTnet/Server/MqttServer.cs Zobrazit soubor

@@ -48,17 +48,12 @@ namespace MQTTnet.Server

public Task<IList<IMqttClientSessionStatus>> GetClientSessionsStatusAsync()
{
return Task.FromResult(_clientSessionsManager.GetClientStatus());
}

public IList<IMqttClientSessionStatus> GetClientSessionsStatus()
{
return _clientSessionsManager.GetClientStatus();
return _clientSessionsManager.GetClientStatusAsync();
}

public IList<MqttApplicationMessage> GetRetainedMessages()
{
return _retainedMessagesManager.GetMessages();
return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult();
}

public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
@@ -123,7 +118,7 @@ namespace MQTTnet.Server

_cancellationTokenSource.Cancel(false);
_clientSessionsManager.Stop();
_clientSessionsManager.StopAsync().ConfigureAwait(false);

foreach (var adapter in _adapters)
{


+ 31
- 0
Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs Zobrazit soubor

@@ -0,0 +1,31 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttServerApplicationMessageInterceptorDelegate : IMqttServerApplicationMessageInterceptor
{
private readonly Func<MqttApplicationMessageInterceptorContext, Task> _callback;

public MqttServerApplicationMessageInterceptorDelegate(Action<MqttApplicationMessageInterceptorContext> callback)
{
if (callback == null) throw new ArgumentNullException(nameof(callback));

_callback = context =>
{
callback(context);
return Task.FromResult(0);
};
}

public MqttServerApplicationMessageInterceptorDelegate(Func<MqttApplicationMessageInterceptorContext, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

public Task InterceptApplicationMessagePublishAsync(MqttApplicationMessageInterceptorContext context)
{
return _callback(context);
}
}
}

+ 31
- 0
Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs Zobrazit soubor

@@ -0,0 +1,31 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttServerConnectionValidatorDelegate : IMqttServerConnectionValidator
{
private readonly Func<MqttConnectionValidatorContext, Task> _callback;

public MqttServerConnectionValidatorDelegate(Action<MqttConnectionValidatorContext> callback)
{
if (callback == null) throw new ArgumentNullException(nameof(callback));

_callback = context =>
{
callback(context);
return Task.FromResult(0);
};
}

public MqttServerConnectionValidatorDelegate(Func<MqttConnectionValidatorContext, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

public Task ValidateConnection(MqttConnectionValidatorContext context)
{
return _callback(context);
}
}
}

+ 5
- 5
Source/MQTTnet/Server/MqttServerOptions.cs Zobrazit soubor

@@ -16,13 +16,13 @@ namespace MQTTnet.Server

public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15);

public Action<MqttConnectionValidatorContext> ConnectionValidator { get; set; }
public IMqttServerConnectionValidator ConnectionValidator { get; set; }

public Action<MqttApplicationMessageInterceptorContext> ApplicationMessageInterceptor { get; set; }
public IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; set; }
public IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; set; }

public Action<MqttClientMessageQueueInterceptorContext> ClientMessageQueueInterceptor { get; set; }

public Action<MqttSubscriptionInterceptorContext> SubscriptionInterceptor { get; set; }
public IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; set; }

public IMqttServerStorage Storage { get; set; }
}


+ 21
- 3
Source/MQTTnet/Server/MqttServerOptionsBuilder.cs Zobrazit soubor

@@ -99,24 +99,42 @@ namespace MQTTnet.Server
return this;
}

public MqttServerOptionsBuilder WithConnectionValidator(Action<MqttConnectionValidatorContext> value)
public MqttServerOptionsBuilder WithConnectionValidator(IMqttServerConnectionValidator value)
{
_options.ConnectionValidator = value;
return this;
}

public MqttServerOptionsBuilder WithApplicationMessageInterceptor(Action<MqttApplicationMessageInterceptorContext> value)
public MqttServerOptionsBuilder WithConnectionValidator(Action<MqttConnectionValidatorContext> value)
{
_options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(value);
return this;
}

public MqttServerOptionsBuilder WithApplicationMessageInterceptor(IMqttServerApplicationMessageInterceptor value)
{
_options.ApplicationMessageInterceptor = value;
return this;
}

public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action<MqttSubscriptionInterceptorContext> value)
public MqttServerOptionsBuilder WithApplicationMessageInterceptor(Action<MqttApplicationMessageInterceptorContext> value)
{
_options.ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(value);
return this;
}

public MqttServerOptionsBuilder WithSubscriptionInterceptor(IMqttServerSubscriptionInterceptor value)
{
_options.SubscriptionInterceptor = value;
return this;
}

public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action<MqttSubscriptionInterceptorContext> value)
{
_options.SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(value);
return this;
}

public MqttServerOptionsBuilder WithPersistentSessions()
{
_options.EnablePersistentSessions = true;


+ 31
- 0
Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs Zobrazit soubor

@@ -0,0 +1,31 @@
using System;
using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttServerSubscriptionInterceptorDelegate : IMqttServerSubscriptionInterceptor
{
private readonly Func<MqttSubscriptionInterceptorContext, Task> _callback;

public MqttServerSubscriptionInterceptorDelegate(Action<MqttSubscriptionInterceptorContext> callback)
{
if (callback == null) throw new ArgumentNullException(nameof(callback));

_callback = context =>
{
callback(context);
return Task.FromResult(0);
};
}

public MqttServerSubscriptionInterceptorDelegate(Func<MqttSubscriptionInterceptorContext, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

public Task InterceptSubscriptionAsync(MqttSubscriptionInterceptorContext context)
{
return _callback(context);
}
}
}

+ 1
- 1
Tests/MQTTnet.Core.Tests/AsyncLockTests.cs Zobrazit soubor

@@ -21,7 +21,7 @@ namespace MQTTnet.Tests
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
threads[i] = Task.Run(async () =>
{
using (var releaser = await @lock.LockAsync(CancellationToken.None))
using (var releaser = await @lock.WaitAsync(CancellationToken.None))
{
var localI = globalI;
await Task.Delay(10); // Increase the chance for wrong data.


+ 2
- 1
Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs Zobrazit soubor

@@ -82,9 +82,10 @@ namespace MQTTnet.Tests
throw new NotSupportedException();
}

public void Stop(MqttClientDisconnectType disconnectType)
public Task StopAsync(MqttClientDisconnectType disconnectType)
{
StopCalledCount++;
return Task.FromResult(0);
}

public Task SubscribeAsync(IList<TopicFilter> topicFilters)


+ 101
- 34
Tests/MQTTnet.Core.Tests/MqttServerTests.cs Zobrazit soubor

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -52,7 +53,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_WillMessage()
public async Task MqttServer_Will_Message()
{
var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -84,7 +85,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_SubscribeUnsubscribe()
public async Task MqttServer_Subscribe_Unsubscribe()
{
var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -230,7 +231,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_SessionTakeover()
public async Task MqttServer_Session_Takeover()
{
var server = new MqttFactory().CreateMqttServer();
try
@@ -299,38 +300,43 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_HandleCleanDisconnect()
public async Task MqttServer_Handle_Clean_Disconnect()
{
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger());
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0;
var s = new MqttFactory().CreateMqttServer();
try
{
var clientConnectedCalled = 0;
var clientDisconnectedCalled = 0;

s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;
s.ClientConnected += (_, __) => clientConnectedCalled++;
s.ClientDisconnected += (_, __) => clientDisconnectedCalled++;

var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();
var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();

await s.StartAsync(new MqttServerOptions());
await s.StartAsync(new MqttServerOptions());

var c1 = new MqttFactory().CreateMqttClient();
var c1 = new MqttFactory().CreateMqttClient();

await c1.ConnectAsync(clientOptions);
await c1.ConnectAsync(clientOptions);

await Task.Delay(100);
await Task.Delay(100);

await c1.DisconnectAsync();
await c1.DisconnectAsync();

await Task.Delay(100);
await Task.Delay(100);

await s.StopAsync();
await s.StopAsync();

await Task.Delay(100);
await Task.Delay(100);

Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
}
finally
{
await s.StopAsync();
}
}

[TestMethod]
@@ -410,11 +416,13 @@ namespace MQTTnet.Tests
.WithPayload("value" + j).WithRetainFlag().Build()).GetAwaiter().GetResult();
}

Thread.Sleep(100);

client.DisconnectAsync().GetAwaiter().GetResult();
}
});

await Task.Delay(100);
await Task.Delay(1000);

var retainedMessages = server.GetRetainedMessages();

@@ -432,7 +440,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_RetainedMessagesFlow()
public async Task MqttServer_Retained_Messages_Flow()
{
var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build();
var serverAdapter = new TestMqttServerAdapter();
@@ -440,9 +448,9 @@ namespace MQTTnet.Tests
await s.StartAsync(new MqttServerOptions());
var c1 = await serverAdapter.ConnectTestClient("c1");
await c1.PublishAsync(retainedMessage);
Thread.Sleep(500);
await Task.Delay(500);
await c1.DisconnectAsync();
Thread.Sleep(500);
await Task.Delay(500);

var receivedMessages = 0;
var c2 = await serverAdapter.ConnectTestClient("c2");
@@ -468,7 +476,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_NoRetainedMessage()
public async Task MqttServer_No_Retained_Message()
{
var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -498,7 +506,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_RetainedMessage()
public async Task MqttServer_Retained_Message()
{
var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -535,7 +543,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_ClearRetainedMessage()
public async Task MqttServer_Clear_Retained_Message()
{
var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
@@ -567,7 +575,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_PersistRetainedMessage()
public async Task MqttServer_Persist_Retained_Message()
{
var storage = new TestStorage();
var serverAdapter = new TestMqttServerAdapter();
@@ -629,7 +637,7 @@ namespace MQTTnet.Tests

try
{
var options = new MqttServerOptions { ApplicationMessageInterceptor = Interceptor };
var options = new MqttServerOptions { ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(c => Interceptor(c)) };

await s.StartAsync(options);

@@ -692,7 +700,7 @@ namespace MQTTnet.Tests
}

[TestMethod]
public async Task MqttServer_ConnectionDenied()
public async Task MqttServer_Connection_Denied()
{
var server = new MqttFactory().CreateMqttServer();
var client = new MqttFactory().CreateMqttClient();
@@ -791,7 +799,6 @@ namespace MQTTnet.Tests
Assert.AreEqual("cdcd", flow);
}


[TestMethod]
public async Task MqttServer_StopAndRestart()
{
@@ -820,6 +827,66 @@ namespace MQTTnet.Tests
await server.StopAsync();
}

[TestMethod]
public async Task MqttServer_Close_Idle_Connection()
{
var server = new MqttFactory().CreateMqttServer();

try
{
await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build());

var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync("localhost", 1883);

// Don't send anything. The server should close the connection.
await Task.Delay(TimeSpan.FromSeconds(5));

try
{
await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
Assert.Fail("Receive should throw an exception.");
}
catch (SocketException)
{
}
}
finally
{
await server.StopAsync();
}
}

[TestMethod]
public async Task MqttServer_Send_Garbage()
{
var server = new MqttFactory().CreateMqttServer();

try
{
await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build());

var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync("localhost", 1883);
await client.SendAsync(Encoding.UTF8.GetBytes("Garbage"), SocketFlags.None);

await Task.Delay(TimeSpan.FromSeconds(5));

try
{
await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
Assert.Fail("Receive should throw an exception.");
}
catch (SocketException)
{
}
}
finally
{
await server.StopAsync();
}
}

private class TestStorage : IMqttServerStorage
{
public IList<MqttApplicationMessage> Messages = new List<MqttApplicationMessage>();


+ 5
- 5
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs Zobrazit soubor

@@ -16,7 +16,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();

var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce);
Assert.IsTrue(result.IsSubscribed);
@@ -31,7 +31,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });

sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();

var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed);
@@ -47,7 +47,7 @@ namespace MQTTnet.Tests
sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });
sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce });

sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();

var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce);
Assert.IsTrue(result.IsSubscribed);
@@ -62,7 +62,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();

Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);
}
@@ -75,7 +75,7 @@ namespace MQTTnet.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.Subscribe(sp);
sm.SubscribeAsync(sp).GetAwaiter().GetResult();

Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed);



+ 7
- 6
Tests/MQTTnet.TestApp.NetCore/ServerTest.cs Zobrazit soubor

@@ -23,7 +23,7 @@ namespace MQTTnet.TestApp.NetCore
{
var options = new MqttServerOptions
{
ConnectionValidator = p =>
ConnectionValidator = new MqttServerConnectionValidatorDelegate(p =>
{
if (p.ClientId == "SpecialClient")
{
@@ -32,11 +32,11 @@ namespace MQTTnet.TestApp.NetCore
p.ReturnCode = MqttConnectReturnCode.ConnectionRefusedBadUsernameOrPassword;
}
}
},
}),

Storage = new RetainedMessageHandler(),

ApplicationMessageInterceptor = context =>
ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(context =>
{
if (MqttTopicFilterComparer.IsMatch(context.ApplicationMessage.Topic, "/myTopic/WithTimestamp/#"))
{
@@ -50,8 +50,9 @@ namespace MQTTnet.TestApp.NetCore
context.AcceptPublish = false;
context.CloseConnection = true;
}
},
SubscriptionInterceptor = context =>
}),

SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(context =>
{
if (context.TopicFilter.Topic.StartsWith("admin/foo/bar") && context.ClientId != "theAdmin")
{
@@ -63,7 +64,7 @@ namespace MQTTnet.TestApp.NetCore
context.AcceptSubscription = false;
context.CloseConnection = true;
}
}
})
};

// Extend the timestamp for all messages from clients.


+ 7
- 7
Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs Zobrazit soubor

@@ -450,7 +450,7 @@ namespace MQTTnet.TestApp.UniversalWindows
return;
}

var sessions = _mqttServer.GetClientSessionsStatus();
var sessions = _mqttServer.GetClientSessionsStatusAsync().GetAwaiter().GetResult();
_sessions.Clear();

foreach (var session in sessions)
@@ -568,7 +568,7 @@ namespace MQTTnet.TestApp.UniversalWindows
{
var options = new MqttServerOptions();

options.ConnectionValidator = c =>
options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(c =>
{
if (c.ClientId.Length < 10)
{
@@ -589,7 +589,7 @@ namespace MQTTnet.TestApp.UniversalWindows
}

c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
};
});

var factory = new MqttFactory();
var mqttServer = factory.CreateMqttServer();
@@ -633,7 +633,7 @@ namespace MQTTnet.TestApp.UniversalWindows
{
};

options.ConnectionValidator = c =>
options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(c =>
{
if (c.ClientId != "Highlander")
{
@@ -642,7 +642,7 @@ namespace MQTTnet.TestApp.UniversalWindows
}

c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
};
});

var mqttServer = new MqttFactory().CreateMqttServer();
await mqttServer.StartAsync(optionsBuilder.Build());
@@ -652,7 +652,7 @@ namespace MQTTnet.TestApp.UniversalWindows
// Setup client validator.
var options = new MqttServerOptions
{
ConnectionValidator = c =>
ConnectionValidator = new MqttServerConnectionValidatorDelegate(c =>
{
if (c.ClientId.Length < 10)
{
@@ -673,7 +673,7 @@ namespace MQTTnet.TestApp.UniversalWindows
}

c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
}
})
};
}



Načítá se…
Zrušit
Uložit