Browse Source

Refactor message processing and async/await usage.

release/3.x.x
Christian Kratky 6 years ago
parent
commit
caea7910b4
31 changed files with 586 additions and 634 deletions
  1. +3
    -1
      Build/MQTTnet.Extensions.ManagedClient.nuspec
  2. +3
    -1
      Build/MQTTnet.Extensions.Rpc.nuspec
  3. +3
    -1
      Build/MQTTnet.nuspec
  4. +1
    -1
      Source/MQTTnet/Adapter/IMqttChannelAdapter.cs
  5. +99
    -61
      Source/MQTTnet/Adapter/MqttChannelAdapter.cs
  6. +18
    -20
      Source/MQTTnet/Client/MqttClient.cs
  7. +2
    -0
      Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs
  8. +1
    -1
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  9. +1
    -1
      Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs
  10. +4
    -3
      Source/MQTTnet/Implementations/MqttTcpServerListener.cs
  11. +2
    -2
      Source/MQTTnet/Internal/TaskExtensions.cs
  12. +1
    -4
      Source/MQTTnet/MQTTnet.csproj
  13. +0
    -48
      Source/MQTTnet/Serializer/ByteReader.cs
  14. +0
    -36
      Source/MQTTnet/Serializer/ByteWriter.cs
  15. +9
    -3
      Source/MQTTnet/Serializer/MqttPacketReader.cs
  16. +110
    -112
      Source/MQTTnet/Serializer/MqttPacketSerializer.cs
  17. +99
    -29
      Source/MQTTnet/Serializer/MqttPacketWriter.cs
  18. +4
    -4
      Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs
  19. +19
    -22
      Source/MQTTnet/Server/MqttClientSession.cs
  20. +139
    -123
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  21. +15
    -0
      Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs
  22. +4
    -7
      Source/MQTTnet/Server/MqttServer.cs
  23. +1
    -1
      Source/MQTTnet/Server/PrepareClientSessionResult.cs
  24. +1
    -1
      Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs
  25. +0
    -30
      Tests/MQTTnet.Core.Tests/ByteReaderTests.cs
  26. +0
    -51
      Tests/MQTTnet.Core.Tests/ByteWriterTests.cs
  27. +6
    -6
      Tests/MQTTnet.Core.Tests/ExtensionTests.cs
  28. +1
    -1
      Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs
  29. +1
    -2
      Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs
  30. +37
    -61
      Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs
  31. +2
    -1
      Tests/MQTTnet.TestApp.NetCore/Program.cs

+ 3
- 1
Build/MQTTnet.Extensions.ManagedClient.nuspec View File

@@ -48,6 +48,8 @@

<!-- .NET Framework -->
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\net452\MQTTnet.Extensions.ManagedClient.*" target="lib\net452\"/>
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\net461\MQTTnet.Extensions.ManagedClient.*" target="lib\net461\"/>

<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\netstandard2.0\MQTTnet.Extensions.ManagedClient.*" target="lib\net461\"/>
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\netstandard2.0\MQTTnet.Extensions.ManagedClient.*" target="lib\net472\"/>
</files>
</package>

+ 3
- 1
Build/MQTTnet.Extensions.Rpc.nuspec View File

@@ -48,6 +48,8 @@

<!-- .NET Framework -->
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\net452\MQTTnet.Extensions.Rpc.*" target="lib\net452\"/>
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\net461\MQTTnet.Extensions.Rpc.*" target="lib\net461\"/>

<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\netstandard2.0\MQTTnet.Extensions.Rpc.*" target="lib\net461\"/>
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\netstandard2.0\MQTTnet.Extensions.Rpc.*" target="lib\net462\"/>
</files>
</package>

+ 3
- 1
Build/MQTTnet.nuspec View File

@@ -69,6 +69,8 @@

<!-- .NET Framework -->
<file src="..\Source\MQTTnet\bin\Release\net452\MQTTnet.*" target="lib\net452\"/>
<file src="..\Source\MQTTnet\bin\Release\net461\MQTTnet.*" target="lib\net461\"/>

<file src="..\Source\MQTTnet\bin\Release\netstandard2.0\MQTTnet.*" target="lib\net461\"/>
<file src="..\Source\MQTTnet\bin\Release\netstandard2.0\MQTTnet.*" target="lib\net472\"/>
</files>
</package>

+ 1
- 1
Source/MQTTnet/Adapter/IMqttChannelAdapter.cs View File

@@ -20,7 +20,7 @@ namespace MQTTnet.Adapter

Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken);

Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken);
Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken);

Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken);
}


+ 99
- 61
Source/MQTTnet/Adapter/MqttChannelAdapter.cs View File

@@ -40,52 +40,84 @@ namespace MQTTnet.Adapter
public event EventHandler ReadingPacketStarted;
public event EventHandler ReadingPacketCompleted;

public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken)
public async Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();
_logger.Verbose("Connecting [Timeout={0}]", timeout);

return ExecuteAndWrapExceptionAsync(() =>
Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken));
try
{
_logger.Verbose("Connecting [Timeout={0}]", timeout);

await Internal.TaskExtensions
.TimeoutAfterAsync(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)
.ConfigureAwait(false);
}
catch (Exception exception)
{
if (IsWrappedException(exception))
{
throw;
}

WrapException(exception);
}
}

public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken)
public async Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();
_logger.Verbose("Disconnecting [Timeout={0}]", timeout);

return ExecuteAndWrapExceptionAsync(() =>
Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, cancellationToken));
try
{
_logger.Verbose("Disconnecting [Timeout={0}]", timeout);

await Internal.TaskExtensions
.TimeoutAfterAsync(ct => _channel.DisconnectAsync(), timeout, cancellationToken)
.ConfigureAwait(false);
}
catch (Exception exception)
{
if (IsWrappedException(exception))
{
throw;
}

WrapException(exception);
}
}

public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken)
public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken)
{
return ExecuteAndWrapExceptionAsync(() =>
try
{
_logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout);
_logger.Verbose("TX >>> {0}", packet);

var packetData = PacketSerializer.Serialize(packet);

return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync(
packetData.Array,
packetData.Offset,
packetData.Count,
ct), timeout, cancellationToken);
});
await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false);
}
catch (Exception exception)
{
if (IsWrappedException(exception))
{
throw;
}

WrapException(exception);
}
}

public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();

MqttBasePacket packet = null;
await ExecuteAndWrapExceptionAsync(async () =>
try
{
ReceivedMqttPacket receivedMqttPacket;

if (timeout > TimeSpan.Zero)
{
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false);
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfterAsync(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false);
}
else
{
@@ -94,19 +126,30 @@ namespace MQTTnet.Adapter

if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested)
{
return;
return null;
}

packet = PacketSerializer.Deserialize(receivedMqttPacket);
var packet = PacketSerializer.Deserialize(receivedMqttPacket);
if (packet == null)
{
throw new MqttProtocolViolationException("Received malformed packet.");
}

_logger.Verbose("RX <<< {0}", packet);
}).ConfigureAwait(false);
return packet;
}
catch (Exception exception)
{
if (IsWrappedException(exception))
{
throw;
}

WrapException(exception);
}

return packet;
return null;
}

private async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken)
@@ -138,7 +181,9 @@ namespace MQTTnet.Adapter
chunkSize = bytesLeft;
}

var readBytes = await channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken) .ConfigureAwait(false);
// async/await is not used to avoid the overhead of context switches. We assume that the reamining data
// has been sent from the sender directly after the initial bytes.
var readBytes = channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken).GetAwaiter().GetResult();
if (readBytes <= 0)
{
ExceptionHelper.ThrowGracefulSocketClose();
@@ -155,42 +200,6 @@ namespace MQTTnet.Adapter
}
}

private static async Task ExecuteAndWrapExceptionAsync(Func<Task> action)
{
try
{
await action().ConfigureAwait(false);
}
catch (Exception exception)
{
if (exception is TaskCanceledException ||
exception is OperationCanceledException ||
exception is MqttCommunicationTimedOutException ||
exception is MqttCommunicationException)
{
throw;
}

if (exception is IOException && exception.InnerException is SocketException socketException)
{
if (socketException.SocketErrorCode == SocketError.ConnectionAborted)
{
throw new OperationCanceledException();
}
}

if (exception is COMException comException)
{
if ((uint)comException.HResult == ErrorOperationAborted)
{
throw new OperationCanceledException();
}
}

throw new MqttCommunicationException(exception);
}
}

public void Dispose()
{
_isDisposed = true;
@@ -205,5 +214,34 @@ namespace MQTTnet.Adapter
throw new ObjectDisposedException(nameof(MqttChannelAdapter));
}
}

private static bool IsWrappedException(Exception exception)
{
return exception is TaskCanceledException ||
exception is OperationCanceledException ||
exception is MqttCommunicationTimedOutException ||
exception is MqttCommunicationException;
}

private static void WrapException(Exception exception)
{
if (exception is IOException && exception.InnerException is SocketException socketException)
{
if (socketException.SocketErrorCode == SocketError.ConnectionAborted)
{
throw new OperationCanceledException();
}
}

if (exception is COMException comException)
{
if ((uint)comException.HResult == ErrorOperationAborted)
{
throw new OperationCanceledException();
}
}

throw new MqttCommunicationException(exception);
}
}
}

+ 18
- 20
Source/MQTTnet/Client/MqttClient.cs View File

@@ -17,7 +17,7 @@ namespace MQTTnet.Client
{
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
private readonly Stopwatch _sendTracker = new Stopwatch();
private readonly SemaphoreSlim _disconnectLock = new SemaphoreSlim(1, 1);
private readonly object _disconnectLock = new object();
private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();

private readonly IMqttClientAdapterFactory _adapterFactory;
@@ -215,7 +215,7 @@ namespace MQTTnet.Client

private async Task DisconnectInternalAsync(Task sender, Exception exception)
{
await InitiateDisconnectAsync().ConfigureAwait(false);
InitiateDisconnect();

var clientWasConnected = IsConnected;
IsConnected = false;
@@ -249,25 +249,23 @@ namespace MQTTnet.Client
}
}

private async Task InitiateDisconnectAsync()
private void InitiateDisconnect()
{
await _disconnectLock.WaitAsync().ConfigureAwait(false);
try
lock (_disconnectLock)
{
if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested)
try
{
return;
}
if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested)
{
return;
}

_cancellationTokenSource.Cancel(false);
}
catch (Exception adapterException)
{
_logger.Warning(adapterException, "Error while initiating disconnect.");
}
finally
{
_disconnectLock.Release();
_cancellationTokenSource.Cancel(false);
}
catch (Exception adapterException)
{
_logger.Warning(adapterException, "Error while initiating disconnect.");
}
}
}

@@ -279,7 +277,7 @@ namespace MQTTnet.Client
}

_sendTracker.Restart();
return _adapter.SendPacketAsync(_options.CommunicationTimeout, packet, cancellationToken);
return _adapter.SendPacketAsync(packet, cancellationToken);
}

private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket
@@ -300,8 +298,8 @@ namespace MQTTnet.Client
var packetAwaiter = _packetDispatcher.AddPacketAwaiter<TResponsePacket>(identifier);
try
{
await _adapter.SendPacketAsync(_options.CommunicationTimeout, requestPacket, cancellationToken).ConfigureAwait(false);
var respone = await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false);
await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false);
var respone = await Internal.TaskExtensions.TimeoutAfterAsync(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false);

return (TResponsePacket)respone;
}


+ 2
- 0
Source/MQTTnet/Diagnostics/TargetFrameworkInfoProvider.cs View File

@@ -10,6 +10,8 @@
return "net452";
#elif NET461
return "net461";
#elif NET472
return "net472";
#elif NETSTANDARD1_3
return "netstandard1.3";
#elif NETSTANDARD2_0


+ 1
- 1
Source/MQTTnet/Implementations/MqttTcpChannel.cs View File

@@ -1,4 +1,4 @@
#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0
#if !WINDOWS_UWP
using System;
using System.Net.Security;
using System.Net.Sockets;


+ 1
- 1
Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs View File

@@ -1,4 +1,4 @@
#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0
#if !WINDOWS_UWP
using System;
using System.Collections.Generic;
using System.Net.Sockets;


+ 4
- 3
Source/MQTTnet/Implementations/MqttTcpServerListener.cs View File

@@ -1,4 +1,4 @@
#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0
#if !WINDOWS_UWP
using System;
using System.Net;
using System.Net.Security;
@@ -76,7 +76,8 @@ namespace MQTTnet.Implementations
await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false);
}

_logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {_addressFamily}'.");
var protocol = _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6";
_logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {protocol}'.");

var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger);
ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter));
@@ -102,7 +103,7 @@ namespace MQTTnet.Implementations
{
_socket?.Dispose();

#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461
#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 || NET472
_tlsCertificate?.Dispose();
#endif
}


+ 2
- 2
Source/MQTTnet/Internal/TaskExtensions.cs View File

@@ -7,7 +7,7 @@ namespace MQTTnet.Internal
{
public static class TaskExtensions
{
public static async Task TimeoutAfter(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken)
public static async Task TimeoutAfterAsync(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken)
{
if (action == null) throw new ArgumentNullException(nameof(action));

@@ -31,7 +31,7 @@ namespace MQTTnet.Internal
}
}

public static async Task<TResult> TimeoutAfter<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken)
public static async Task<TResult> TimeoutAfterAsync<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken)
{
if (action == null) throw new ArgumentNullException(nameof(action));



+ 1
- 4
Source/MQTTnet/MQTTnet.csproj View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks Condition=" '$(OS)' == 'Windows_NT' ">netstandard1.3;netstandard2.0;net452;net461;uap10.0</TargetFrameworks>
<TargetFrameworks Condition=" '$(OS)' == 'Windows_NT' ">netstandard1.3;netstandard2.0;net452;uap10.0</TargetFrameworks>
<TargetFrameworks Condition=" '$(OS)' != 'Windows_NT' ">netstandard1.3;netstandard2.0</TargetFrameworks>
<AssemblyName>MQTTnet</AssemblyName>
<RootNamespace>MQTTnet</RootNamespace>
@@ -62,7 +62,4 @@
<ItemGroup Condition="'$(TargetFramework)'=='net452'">
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)'=='net461'">
</ItemGroup>

</Project>

+ 0
- 48
Source/MQTTnet/Serializer/ByteReader.cs View File

@@ -1,48 +0,0 @@
using System;

namespace MQTTnet.Serializer
{
public class ByteReader
{
private readonly int _source;
private int _index;

public ByteReader(int source)
{
_source = source;
}

public bool Read()
{
if (_index >= 8)
{
throw new InvalidOperationException("End of byte reached.");
}

var result = ((1 << _index) & _source) > 0;
_index++;
return result;
}

public int Read(int count)
{
if (_index + count > 8)
{
throw new InvalidOperationException("End of byte will be reached.");
}

var result = 0;
for (var i = 0; i < count; i++)
{
if (((1 << _index) & _source) > 0)
{
result |= 1 << i;
}

_index++;
}

return result;
}
}
}

+ 0
- 36
Source/MQTTnet/Serializer/ByteWriter.cs View File

@@ -1,36 +0,0 @@
using System;

namespace MQTTnet.Serializer
{
public class ByteWriter
{
private int _index;
private int _byte;

public byte Value => (byte)_byte;

public void Write(int @byte, int count)
{
for (var i = 0; i < count; i++)
{
var value = ((1 << i) & @byte) > 0;
Write(value);
}
}

public void Write(bool bit)
{
if (_index >= 8)
{
throw new InvalidOperationException("End of the byte reached.");
}

if (bit)
{
_byte |= 1 << _index;
}

_index++;
}
}
}

+ 9
- 3
Source/MQTTnet/Serializer/MqttPacketReader.cs View File

@@ -12,6 +12,8 @@ namespace MQTTnet.Serializer
{
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length.
// So in all cases at least 2 bytes must be read for a complete MQTT packet.
// async/await is used here because the next packet is received in a couple of minutes so the performance
// impact is acceptable according to a useless waiting thread.
var buffer = new byte[2];
var totalBytesRead = 0;

@@ -37,11 +39,11 @@ namespace MQTTnet.Serializer
return new MqttFixedHeader(buffer[0], 0);
}

var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken).ConfigureAwait(false);
var bodyLength = ReadBodyLength(channel, buffer[1], cancellationToken);
return new MqttFixedHeader(buffer[0], bodyLength);
}

private static async Task<int> ReadBodyLengthAsync(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken)
private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken)
{
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html.
var multiplier = 128;
@@ -50,7 +52,11 @@ namespace MQTTnet.Serializer

while ((encodedByte & 128) != 0)
{
encodedByte = await ReadByteAsync(channel, cancellationToken).ConfigureAwait(false);
// Here the async/await pattern is not used becuase the overhead of context switches
// is too big for reading 1 byte in a row. We expect that the remaining data was sent
// directly after the initial bytes. If the client disconnects just in this moment we
// will get an exception anyway.
encodedByte = ReadByteAsync(channel, cancellationToken).GetAwaiter().GetResult();

value += (byte)(encodedByte & 127) * multiplier;
if (multiplier > 128 * 128 * 128)


+ 110
- 112
Source/MQTTnet/Serializer/MqttPacketSerializer.cs View File

@@ -2,7 +2,6 @@
using MQTTnet.Packets;
using MQTTnet.Protocol;
using System;
using System.IO;
using System.Linq;
using MQTTnet.Adapter;

@@ -18,57 +17,46 @@ namespace MQTTnet.Serializer
{
if (packet == null) throw new ArgumentNullException(nameof(packet));

using (var stream = new MemoryStream(128))
{
// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes)
stream.Seek(5, SeekOrigin.Begin);
var packetWriter = new MqttPacketWriter();

var fixedHeader = SerializePacket(packet, stream);
var remainingLength = (int)stream.Length - 5;
// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes)
packetWriter.Seek(5);

var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength);
var fixedHeader = SerializePacket(packet, packetWriter);
var remainingLength = packetWriter.Length - 5;

var headerSize = FixedHeaderSize + remainingLengthBuffer.Count;
var headerOffset = 5 - headerSize;
var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength);

// Position cursor on correct offset on beginining of array (has leading 0x0)
stream.Seek(headerOffset, SeekOrigin.Begin);
stream.WriteByte(fixedHeader);
stream.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count);
var headerSize = FixedHeaderSize + remainingLengthBuffer.Count;
var headerOffset = 5 - headerSize;

#if NET461 || NET452 || NETSTANDARD2_0
var buffer = stream.GetBuffer();
return new ArraySegment<byte>(buffer, headerOffset, (int)stream.Length - headerOffset);
#else
if (stream.TryGetBuffer(out var segment))
{
return new ArraySegment<byte>(segment.Array, headerOffset, segment.Count - headerOffset);
}
// Position cursor on correct offset on beginining of array (has leading 0x0)
packetWriter.Seek(headerOffset);
packetWriter.Write(fixedHeader);
packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count);

var buffer = stream.ToArray();
return new ArraySegment<byte>(buffer, headerOffset, buffer.Length - headerOffset);
#endif
}
var buffer = packetWriter.GetBuffer();
return new ArraySegment<byte>(buffer, headerOffset, packetWriter.Length - headerOffset);
}

private byte SerializePacket(MqttBasePacket packet, Stream stream)
private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter)
{
switch (packet)
{
case MqttConnectPacket connectPacket: return Serialize(connectPacket, stream);
case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, stream);
case MqttConnectPacket connectPacket: return Serialize(connectPacket, packetWriter);
case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, packetWriter);
case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect);
case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq);
case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp);
case MqttPublishPacket publishPacket: return Serialize(publishPacket, stream);
case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, stream);
case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, stream);
case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, stream);
case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, stream);
case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, stream);
case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, stream);
case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, stream);
case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, stream);
case MqttPublishPacket publishPacket: return Serialize(publishPacket, packetWriter);
case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, packetWriter);
case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, packetWriter);
case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, packetWriter);
case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, packetWriter);
case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, packetWriter);
case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, packetWriter);
case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, packetWriter);
case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, packetWriter);
default: throw new MqttProtocolViolationException("Packet type invalid.");
}
}
@@ -195,10 +183,9 @@ namespace MQTTnet.Serializer
var body = receivedMqttPacket.Body;
ThrowIfBodyIsEmpty(body);

var fixedHeader = new ByteReader(receivedMqttPacket.FixedHeader);
var retain = fixedHeader.Read();
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2);
var dup = fixedHeader.Read();
var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0;
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3);
var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0;

var topic = body.ReadStringWithLengthPrefix();

@@ -253,8 +240,8 @@ namespace MQTTnet.Serializer
throw new MqttProtocolViolationException($"Protocol name ({protocolName}) is not supported.");
}

var connectFlags = new ByteReader(body.ReadByte());
if (connectFlags.Read())
var connectFlags = body.ReadByte();
if ((connectFlags & 0x1) > 0)
{
throw new MqttProtocolViolationException("The first bit of the Connect Flags must be set to 0.");
}
@@ -262,14 +249,14 @@ namespace MQTTnet.Serializer
var packet = new MqttConnectPacket
{
ProtocolVersion = protocolVersion,
CleanSession = connectFlags.Read()
CleanSession = (connectFlags & 0x2) > 0
};

var willFlag = connectFlags.Read();
var willQoS = connectFlags.Read(2);
var willRetain = connectFlags.Read();
var passwordFlag = connectFlags.Read();
var usernameFlag = connectFlags.Read();
var willFlag = (connectFlags & 0x4) > 0;
var willQoS = (connectFlags & 0x18) >> 3;
var willRetain = (connectFlags & 0x20) > 0;
var passwordFlag = (connectFlags & 0x40) > 0;
var usernameFlag = (connectFlags & 0x80) > 0;

packet.KeepAlivePeriod = body.ReadUInt16();
packet.ClientId = body.ReadStringWithLengthPrefix();
@@ -322,11 +309,11 @@ namespace MQTTnet.Serializer

var packet = new MqttConnAckPacket();

var firstByteReader = new ByteReader(body.ReadByte());
var acknowledgeFlags = body.ReadByte();
if (ProtocolVersion == MqttProtocolVersion.V311)
{
packet.IsSessionPresent = firstByteReader.Read();
packet.IsSessionPresent = (acknowledgeFlags & 0x1) > 0;
}

packet.ConnectReturnCode = (MqttConnectReturnCode)body.ReadByte();
@@ -344,119 +331,129 @@ namespace MQTTnet.Serializer
}
}

// ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local
private static void ValidatePublishPacket(MqttPublishPacket packet)
{
if (packet == null) throw new ArgumentNullException(nameof(packet));

if (packet.QualityOfServiceLevel == 0 && packet.Dup)
{
throw new MqttProtocolViolationException("Dup flag must be false for QoS 0 packets [MQTT-3.3.1-2].");
}
}

private byte Serialize(MqttConnectPacket packet, Stream stream)
private byte Serialize(MqttConnectPacket packet, MqttPacketWriter packetWriter)
{
ValidateConnectPacket(packet);

// Write variable header
if (ProtocolVersion == MqttProtocolVersion.V311)
{
stream.WriteWithLengthPrefix("MQTT");
stream.WriteByte(4); // 3.1.2.2 Protocol Level 4
packetWriter.WriteWithLengthPrefix("MQTT");
packetWriter.Write(4); // 3.1.2.2 Protocol Level 4
}
else
{
stream.WriteWithLengthPrefix("MQIsdp");
stream.WriteByte(3); // Protocol Level 3
packetWriter.WriteWithLengthPrefix("MQIsdp");
packetWriter.Write(3); // Protocol Level 3
}

var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags
connectFlags.Write(false); // Reserved
connectFlags.Write(packet.CleanSession);
connectFlags.Write(packet.WillMessage != null);

if (packet.WillMessage != null)
byte connectFlags = 0x0;
if (packet.CleanSession)
{
connectFlags.Write((int)packet.WillMessage.QualityOfServiceLevel, 2);
connectFlags.Write(packet.WillMessage.Retain);
connectFlags |= 0x2;
}
else

if (packet.WillMessage != null)
{
connectFlags.Write(0, 2);
connectFlags.Write(false);
}
connectFlags |= 0x4;
connectFlags |= (byte)((byte)packet.WillMessage.QualityOfServiceLevel << 3);

if (packet.WillMessage.Retain)
{
connectFlags |= 0x20;
}
}
if (packet.Password != null && packet.Username == null)
{
throw new MqttProtocolViolationException("If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22].");
}

connectFlags.Write(packet.Password != null);
connectFlags.Write(packet.Username != null);
if (packet.Password != null)
{
connectFlags |= 0x40;
}

stream.Write(connectFlags);
stream.Write(packet.KeepAlivePeriod);
stream.WriteWithLengthPrefix(packet.ClientId);
if (packet.Username != null)
{
connectFlags |= 0x80;
}
packetWriter.Write(connectFlags);
packetWriter.Write(packet.KeepAlivePeriod);
packetWriter.WriteWithLengthPrefix(packet.ClientId);

if (packet.WillMessage != null)
{
stream.WriteWithLengthPrefix(packet.WillMessage.Topic);
stream.WriteWithLengthPrefix(packet.WillMessage.Payload);
packetWriter.WriteWithLengthPrefix(packet.WillMessage.Topic);
packetWriter.WriteWithLengthPrefix(packet.WillMessage.Payload);
}

if (packet.Username != null)
{
stream.WriteWithLengthPrefix(packet.Username);
packetWriter.WriteWithLengthPrefix(packet.Username);
}

if (packet.Password != null)
{
stream.WriteWithLengthPrefix(packet.Password);
packetWriter.WriteWithLengthPrefix(packet.Password);
}

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect);
}

private byte Serialize(MqttConnAckPacket packet, Stream stream)
private byte Serialize(MqttConnAckPacket packet, MqttPacketWriter packetWriter)
{
if (ProtocolVersion == MqttProtocolVersion.V310)
{
stream.WriteByte(0);
packetWriter.Write(0);
}
else if (ProtocolVersion == MqttProtocolVersion.V311)
{
var connectAcknowledgeFlags = new ByteWriter();
connectAcknowledgeFlags.Write(packet.IsSessionPresent);

stream.Write(connectAcknowledgeFlags);
byte connectAcknowledgeFlags = 0x0;
if (packet.IsSessionPresent)
{
connectAcknowledgeFlags |= 0x1;
}
packetWriter.Write(connectAcknowledgeFlags);
}
else
{
throw new MqttProtocolViolationException("Protocol version not supported.");
}

stream.WriteByte((byte)packet.ConnectReturnCode);
packetWriter.Write((byte)packet.ConnectReturnCode);

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck);
}

private static byte Serialize(MqttPubRelPacket packet, Stream stream)
private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter packetWriter)
{
if (!packet.PacketIdentifier.HasValue)
{
throw new MqttProtocolViolationException("PubRel packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02);
}

private static byte Serialize(MqttPublishPacket packet, Stream stream)
private static byte Serialize(MqttPublishPacket packet, MqttPacketWriter packetWriter)
{
ValidatePublishPacket(packet);

stream.WriteWithLengthPrefix(packet.Topic);
packetWriter.WriteWithLengthPrefix(packet.Topic);

if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce)
{
@@ -465,7 +462,7 @@ namespace MQTTnet.Serializer
throw new MqttProtocolViolationException("Publish packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);
}
else
{
@@ -477,7 +474,7 @@ namespace MQTTnet.Serializer

if (packet.Payload?.Length > 0)
{
stream.Write(packet.Payload, 0, packet.Payload.Length);
packetWriter.Write(packet.Payload, 0, packet.Payload.Length);
}

byte fixedHeader = 0;
@@ -497,43 +494,43 @@ namespace MQTTnet.Serializer
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader);
}

private static byte Serialize(MqttPubAckPacket packet, Stream stream)
private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter packetWriter)
{
if (!packet.PacketIdentifier.HasValue)
{
throw new MqttProtocolViolationException("PubAck packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck);
}

private static byte Serialize(MqttPubRecPacket packet, Stream stream)
private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter packetWriter)
{
if (!packet.PacketIdentifier.HasValue)
{
throw new MqttProtocolViolationException("PubRec packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec);
}

private static byte Serialize(MqttPubCompPacket packet, Stream stream)
private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter packetWriter)
{
if (!packet.PacketIdentifier.HasValue)
{
throw new MqttProtocolViolationException("PubComp packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp);
}

private static byte Serialize(MqttSubscribePacket packet, Stream stream)
private static byte Serialize(MqttSubscribePacket packet, MqttPacketWriter packetWriter)
{
if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3].");

@@ -542,41 +539,41 @@ namespace MQTTnet.Serializer
throw new MqttProtocolViolationException("Subscribe packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

if (packet.TopicFilters?.Count > 0)
{
foreach (var topicFilter in packet.TopicFilters)
{
stream.WriteWithLengthPrefix(topicFilter.Topic);
stream.WriteByte((byte)topicFilter.QualityOfServiceLevel);
packetWriter.WriteWithLengthPrefix(topicFilter.Topic);
packetWriter.Write((byte)topicFilter.QualityOfServiceLevel);
}
}

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02);
}

private static byte Serialize(MqttSubAckPacket packet, Stream stream)
private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter packetWriter)
{
if (!packet.PacketIdentifier.HasValue)
{
throw new MqttProtocolViolationException("SubAck packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

if (packet.SubscribeReturnCodes?.Any() == true)
{
foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes)
{
stream.WriteByte((byte)packetSubscribeReturnCode);
packetWriter.Write((byte)packetSubscribeReturnCode);
}
}

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck);
}

private static byte Serialize(MqttUnsubscribePacket packet, Stream stream)
private static byte Serialize(MqttUnsubscribePacket packet, MqttPacketWriter packetWriter)
{
if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2].");

@@ -585,27 +582,27 @@ namespace MQTTnet.Serializer
throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);

if (packet.TopicFilters?.Any() == true)
{
foreach (var topicFilter in packet.TopicFilters)
{
stream.WriteWithLengthPrefix(topicFilter);
packetWriter.WriteWithLengthPrefix(topicFilter);
}
}

return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02);
}

private static byte Serialize(MqttUnsubAckPacket packet, Stream stream)
private static byte Serialize(MqttUnsubAckPacket packet, MqttPacketWriter packetWriter)
{
if (!packet.PacketIdentifier.HasValue)
{
throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier.");
}

stream.Write(packet.PacketIdentifier.Value);
packetWriter.Write(packet.PacketIdentifier.Value);
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck);
}

@@ -614,6 +611,7 @@ namespace MQTTnet.Serializer
return MqttPacketWriter.BuildFixedHeader(type);
}

// ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local
private static void ThrowIfBodyIsEmpty(MqttPacketBodyReader body)
{
if (body == null || body.Length == 0)


+ 99
- 29
Source/MQTTnet/Serializer/MqttPacketWriter.cs View File

@@ -1,12 +1,23 @@
using System;
using System.IO;
using System.Text;
using MQTTnet.Protocol;

namespace MQTTnet.Serializer
{
public static class MqttPacketWriter
/// <summary>
/// This is a custom implementation of a memory stream which provides only MQTTnet relevant features.
/// The goal is to avoid lots of argument checks like in the original stream. The growth rule is the
/// same as for the original MemoryStream in .net. Also this implementation allows accessing the internal
/// buffer for all platforms and .net framework versions (which is not available at the regular MemoryStream).
/// </summary>
public class MqttPacketWriter
{
private byte[] _buffer = new byte[128];

private int _position;

public int Length { get; private set; }

public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0)
{
var fixedHeader = (int)packetType << 4;
@@ -14,33 +25,6 @@ namespace MQTTnet.Serializer
return (byte)fixedHeader;
}

public static void Write(this Stream stream, ushort value)
{
var buffer = BitConverter.GetBytes(value);
stream.WriteByte(buffer[1]);
stream.WriteByte(buffer[0]);
}

public static void Write(this Stream stream, ByteWriter value)
{
if (value == null) throw new ArgumentNullException(nameof(value));

stream.WriteByte(value.Value);
}

public static void WriteWithLengthPrefix(this Stream stream, string value)
{
stream.WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty));
}

public static void WriteWithLengthPrefix(this Stream stream, byte[] value)
{
var length = (ushort)value.Length;

stream.Write(length);
stream.Write(value, 0, length);
}

public static ArraySegment<byte> EncodeRemainingLength(int length)
{
// write the encoded remaining length right aligned on the 4 byte buffer
@@ -69,5 +53,91 @@ namespace MQTTnet.Serializer

return new ArraySegment<byte>(buffer, 0, bufferOffset);
}

public void WriteWithLengthPrefix(string value)
{
WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty));
}

public void WriteWithLengthPrefix(byte[] value)
{
EnsureAdditionalCapacity(value.Length + 2);

Write((ushort)value.Length);
Write(value, 0, value.Length);
}
public void Write(byte @byte)
{
EnsureAdditionalCapacity(1);

_buffer[_position] = @byte;
IncreasePostition(1);
}

public void Write(ushort value)
{
EnsureAdditionalCapacity(2);

_buffer[_position] = (byte)(value >> 8);
IncreasePostition(1);
_buffer[_position] = (byte)value;
IncreasePostition(1);
}

public void Write(byte[] array, int offset, int count)
{
EnsureAdditionalCapacity(count);

Array.Copy(array, offset, _buffer, _position, count);
IncreasePostition(count);
}

public void Seek(int offset)
{
EnsureCapacity(offset);
_position = offset;
}

public byte[] GetBuffer()
{
return _buffer;
}

private void EnsureAdditionalCapacity(int additionalCapacity)
{
var freeSpace = _buffer.Length - _position;
if (freeSpace >= additionalCapacity)
{
return;
}

EnsureCapacity(additionalCapacity - freeSpace);
}

private void EnsureCapacity(int capacity)
{
if (_buffer.Length >= capacity)
{
return;
}

var newBufferLength = _buffer.Length;
while (newBufferLength < capacity)
{
newBufferLength *= 2;
}

Array.Resize(ref _buffer, newBufferLength);
}

private void IncreasePostition(int length)
{
_position += length;
if (_position > Length)
{
Length = _position;
}
}
}
}

Source/MQTTnet/Server/MqttClientPendingMessagesQueue.cs → Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs View File

@@ -11,7 +11,7 @@ using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public class MqttClientPendingMessagesQueue : IDisposable
public class MqttClientPendingPacketsQueue : IDisposable
{
private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent();
private readonly IMqttServerOptions _options;
@@ -20,13 +20,13 @@ namespace MQTTnet.Server

private ConcurrentQueue<MqttBasePacket> _queue = new ConcurrentQueue<MqttBasePacket>();

public MqttClientPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger)
{
if (logger == null) throw new ArgumentNullException(nameof(logger));
_options = options ?? throw new ArgumentNullException(nameof(options));
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession));

_logger = logger.CreateChildLogger(nameof(MqttClientPendingMessagesQueue));
_logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue));
}

public int Count => _queue.Count;
@@ -115,7 +115,7 @@ namespace MQTTnet.Server
return;
}

await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, packet, cancellationToken).ConfigureAwait(false);
await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false);

_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId);
}

+ 19
- 22
Source/MQTTnet/Server/MqttClientSession.cs View File

@@ -18,7 +18,7 @@ namespace MQTTnet.Server

private readonly MqttRetainedMessagesManager _retainedMessagesManager;
private readonly MqttClientKeepAliveMonitor _keepAliveMonitor;
private readonly MqttClientPendingMessagesQueue _pendingMessagesQueue;
private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue;
private readonly MqttClientSubscriptionsManager _subscriptionsManager;
private readonly MqttClientSessionsManager _sessionsManager;

@@ -49,7 +49,7 @@ namespace MQTTnet.Server

_keepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, () => Stop(MqttClientDisconnectType.NotClean), _logger);
_subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server);
_pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger);
_pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger);
}

public string ClientId { get; }
@@ -60,7 +60,7 @@ namespace MQTTnet.Server
status.IsConnected = _adapter != null;
status.Endpoint = _adapter?.Endpoint;
status.ProtocolVersion = _adapter?.PacketSerializer?.ProtocolVersion;
status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count;
status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count;
status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived;
status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived;
}
@@ -80,7 +80,7 @@ namespace MQTTnet.Server
_wasCleanDisconnect = false;
_willMessage = connectPacket.WillMessage;

_pendingMessagesQueue.Start(adapter, _cancellationTokenSource.Token);
_pendingPacketsQueue.Start(adapter, _cancellationTokenSource.Token);
_keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token);

while (!_cancellationTokenSource.IsCancellationRequested)
@@ -149,13 +149,10 @@ namespace MQTTnet.Server

if (_willMessage != null && !_wasCleanDisconnect)
{
_sessionsManager.StartDispatchApplicationMessage(this, _willMessage);
_sessionsManager.EnqueueApplicationMessage(this, _willMessage);
}

_willMessage = null;

////_pendingMessagesQueue.WaitForCompletion();
////_keepAliveMonitor.WaitForCompletion();
}
finally
{
@@ -196,7 +193,7 @@ namespace MQTTnet.Server
}
}
_pendingMessagesQueue.Enqueue(publishPacket);
_pendingPacketsQueue.Enqueue(publishPacket);
}

public Task SubscribeAsync(IList<TopicFilter> topicFilters)
@@ -226,12 +223,12 @@ namespace MQTTnet.Server

public void ClearPendingApplicationMessages()
{
_pendingMessagesQueue.Clear();
_pendingPacketsQueue.Clear();
}

public void Dispose()
{
_pendingMessagesQueue?.Dispose();
_pendingPacketsQueue?.Dispose();

_cancellationTokenSource?.Dispose();
}
@@ -245,7 +242,7 @@ namespace MQTTnet.Server

if (packet is MqttPingReqPacket)
{
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket(), cancellationToken);
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken);
}

if (packet is MqttPubRelPacket pubRelPacket)
@@ -260,7 +257,7 @@ namespace MQTTnet.Server
PacketIdentifier = pubRecPacket.PacketIdentifier
};

return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, responsePacket, cancellationToken);
return adapter.SendPacketAsync(responsePacket, cancellationToken);
}

if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
@@ -308,7 +305,7 @@ namespace MQTTnet.Server
private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
{
var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket);
await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false);
await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false);

if (subscribeResult.CloseConnection)
{
@@ -322,7 +319,7 @@ namespace MQTTnet.Server
private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
{
var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket);
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, unsubscribeResult, cancellationToken);
return adapter.SendPacketAsync(unsubscribeResult, cancellationToken);
}

private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
@@ -333,7 +330,7 @@ namespace MQTTnet.Server
{
case MqttQualityOfServiceLevel.AtMostOnce:
{
_sessionsManager.StartDispatchApplicationMessage(this, applicationMessage);
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage);
return Task.FromResult(0);
}
case MqttQualityOfServiceLevel.AtLeastOnce:
@@ -353,25 +350,25 @@ namespace MQTTnet.Server

private Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{
_sessionsManager.StartDispatchApplicationMessage(this, applicationMessage);
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage);

var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier };
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken);
return adapter.SendPacketAsync(response, cancellationToken);
}

private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{
// QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery)
_sessionsManager.StartDispatchApplicationMessage(this, applicationMessage);
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage);

var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier };
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken);
return adapter.SendPacketAsync(response, cancellationToken);
}

private Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken)
private static Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken)
{
var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier };
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken);
return adapter.SendPacketAsync(response, cancellationToken);
}

private void OnAdapterReadingPacketCompleted(object sender, EventArgs e)


+ 139
- 123
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -6,27 +6,29 @@ using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public class MqttClientSessionsManager : IDisposable
public class MqttClientSessionsManager
{
private readonly BlockingCollection<MqttEnqueuedApplicationMessage> _messageQueue = new BlockingCollection<MqttEnqueuedApplicationMessage>();
private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();
private readonly AsyncLock _sessionPreparationLock = new AsyncLock();

private readonly CancellationToken _cancellationToken;

private readonly MqttRetainedMessagesManager _retainedMessagesManager;
private readonly IMqttServerOptions _options;
private readonly IMqttNetChildLogger _logger;

public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetChildLogger logger)
public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, CancellationToken cancellationToken, IMqttNetChildLogger logger)
{
if (logger == null) throw new ArgumentNullException(nameof(logger));

_logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager));

_cancellationToken = cancellationToken;
_options = options ?? throw new ArgumentNullException(nameof(options));
Server = server ?? throw new ArgumentNullException(nameof(server));
_retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager));
@@ -34,7 +36,129 @@ namespace MQTTnet.Server

public MqttServer Server { get; }

public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
public void Start()
{
Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
}

public Task StopAsync()
{
foreach (var session in _sessions)
{
session.Value.Stop(MqttClientDisconnectType.NotClean);
}

_sessions.Clear();
return Task.FromResult(0);
}

public Task StartSession(IMqttChannelAdapter clientAdapter)
{
return Task.Run(() => RunSession(clientAdapter, _cancellationToken), _cancellationToken);
}

public Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
{
var result = new List<IMqttClientSessionStatus>();
foreach (var session in _sessions)
{
var status = new MqttClientSessionStatus(this, session.Value);
session.Value.FillStatus(status);

result.Add(status);
}

return Task.FromResult((IList<IMqttClientSessionStatus>)result);
}

public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
{
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

_messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken);
}

public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
}

return session.SubscribeAsync(topicFilters);
}

public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
}

return session.UnsubscribeAsync(topicFilters);
}

public void DeleteSession(string clientId)
{
_sessions.TryRemove(clientId, out _);
_logger.Verbose("Session for client '{0}' deleted.", clientId);
}

private void ProcessQueuedApplicationMessages(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken);
var sender = enqueuedApplicationMessage.Sender;
var applicationMessage = enqueuedApplicationMessage.ApplicationMessage;

var interceptorContext = InterceptApplicationMessage(sender, applicationMessage);
if (interceptorContext != null)
{
if (interceptorContext.CloseConnection)
{
enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean);
}

if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish)
{
return;
}

applicationMessage = interceptorContext.ApplicationMessage;
}

Server.OnApplicationMessageReceived(sender?.ClientId, applicationMessage);

if (applicationMessage.Retain)
{
_retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult();
}

foreach (var clientSession in _sessions.Values)
{
clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage);
}
}
catch (TaskCanceledException)
{
}
catch (Exception exception)
{
_logger.Error(exception, "Unhandled exception while processing queued application message.");
}
}
}

private async Task RunSession(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
{
var clientId = string.Empty;
var wasCleanDisconnect = false;
@@ -60,7 +184,7 @@ namespace MQTTnet.Server
var connectReturnCode = ValidateConnection(connectPacket);
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{
await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout,
await clientAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ConnectReturnCode = connectReturnCode
@@ -70,15 +194,15 @@ namespace MQTTnet.Server
return;
}

var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);
var result = PrepareClientSession(connectPacket);
var clientSession = result.Session;

await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout,
await clientAdapter.SendPacketAsync(
new MqttConnAckPacket
{
ConnectReturnCode = connectReturnCode,
IsSessionPresent = result.IsExistingSession
},
},
cancellationToken).ConfigureAwait(false);

Server.OnClientConnected(clientId);
@@ -113,73 +237,6 @@ namespace MQTTnet.Server
}
}

public Task StopAsync()
{
foreach (var session in _sessions)
{
session.Value.Stop(MqttClientDisconnectType.NotClean);
}

_sessions.Clear();
return Task.FromResult(0);
}

public Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
{
var result = new List<IMqttClientSessionStatus>();
foreach (var session in _sessions)
{
var status = new MqttClientSessionStatus(this, session.Value);
session.Value.FillStatus(status);

result.Add(status);
}

return Task.FromResult((IList<IMqttClientSessionStatus>)result);
}

public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
{
Task.Run(() => DispatchApplicationMessageAsync(senderClientSession, applicationMessage));
}

public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
}

return session.SubscribeAsync(topicFilters);
}

public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
}

return session.UnsubscribeAsync(topicFilters);
}

public void DeleteSession(string clientId)
{
_sessions.TryRemove(clientId, out _);
_logger.Verbose("Session for client '{0}' deleted.", clientId);
}

public void Dispose()
{
_sessionPreparationLock?.Dispose();
}

private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket)
{
if (_options.ConnectionValidator == null)
@@ -197,9 +254,9 @@ namespace MQTTnet.Server
return context.ReturnCode;
}

private async Task<GetOrCreateClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket)
{
using (await _sessionPreparationLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
lock (_sessions)
{
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
if (isSessionPresent)
@@ -231,60 +288,19 @@ namespace MQTTnet.Server
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
}

return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession };
return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession };
}
}

private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage)
{
try
{
var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage);
if (interceptorContext != null)
{
if (interceptorContext.CloseConnection)
{
senderClientSession.Stop(MqttClientDisconnectType.NotClean);
}

if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish)
{
return;
}

applicationMessage = interceptorContext.ApplicationMessage;
}

Server.OnApplicationMessageReceived(senderClientSession?.ClientId, applicationMessage);

if (applicationMessage.Retain)
{
await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false);
}

foreach (var clientSession in _sessions.Values)
{
clientSession.EnqueueApplicationMessage(senderClientSession, applicationMessage);
}
}
catch (Exception exception)
{
_logger.Error(exception, "Error while processing application message");
}
}

private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
{
var interceptorContext = new MqttApplicationMessageInterceptorContext(
senderClientSession?.ClientId,
applicationMessage);

var interceptor = _options.ApplicationMessageInterceptor;
if (interceptor == null)
{
return interceptorContext;
return null;
}

var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage);
interceptor(interceptorContext);
return interceptorContext;
}


+ 15
- 0
Source/MQTTnet/Server/MqttEnqueuedApplicationMessage.cs View File

@@ -0,0 +1,15 @@
namespace MQTTnet.Server
{
public class MqttEnqueuedApplicationMessage
{
public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage)
{
Sender = sender;
ApplicationMessage = applicationMessage;
}

public MqttClientSession Sender { get; }

public MqttApplicationMessage ApplicationMessage { get; }
}
}

+ 4
- 7
Source/MQTTnet/Server/MqttServer.cs View File

@@ -65,7 +65,7 @@ namespace MQTTnet.Server

if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started.");

_clientSessionsManager.StartDispatchApplicationMessage(null, applicationMessage);
_clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage);

return Task.FromResult(0);
}
@@ -81,7 +81,8 @@ namespace MQTTnet.Server
_retainedMessagesManager = new MqttRetainedMessagesManager(Options, _logger);
await _retainedMessagesManager.LoadMessagesAsync().ConfigureAwait(false);

_clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _logger);
_clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _cancellationTokenSource.Token, _logger);
_clientSessionsManager.Start();

foreach (var adapter in _adapters)
{
@@ -118,8 +119,6 @@ namespace MQTTnet.Server
}
finally
{
_clientSessionsManager?.Dispose();
_cancellationTokenSource = null;
_retainedMessagesManager = null;
_clientSessionsManager = null;
@@ -155,9 +154,7 @@ namespace MQTTnet.Server

private void OnClientAccepted(object sender, MqttServerAdapterClientAcceptedEventArgs eventArgs)
{
eventArgs.SessionTask = Task.Run(
() => _clientSessionsManager.RunSessionAsync(eventArgs.Client, _cancellationTokenSource.Token),
_cancellationTokenSource.Token);
eventArgs.SessionTask = _clientSessionsManager.StartSession(eventArgs.Client);
}
}
}

Source/MQTTnet/Server/GetOrCreateClientSessionResult.cs → Source/MQTTnet/Server/PrepareClientSessionResult.cs View File

@@ -1,6 +1,6 @@
namespace MQTTnet.Server
{
public class GetOrCreateClientSessionResult
public class PrepareClientSessionResult
{
public bool IsExistingSession { get; set; }


+ 1
- 1
Tests/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs View File

@@ -65,7 +65,7 @@ namespace MQTTnet.Benchmarks

for (var i = 0; i < 10000; i++)
{
_channelAdapter.SendPacketAsync(TimeSpan.FromSeconds(15), _packet, CancellationToken.None).GetAwaiter().GetResult();
_channelAdapter.SendPacketAsync(_packet, CancellationToken.None).GetAwaiter().GetResult();
}

_stream.Position = 0;


+ 0
- 30
Tests/MQTTnet.Core.Tests/ByteReaderTests.cs View File

@@ -1,30 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Serializer;

namespace MQTTnet.Core.Tests
{
[TestClass]
public class ByteReaderTests
{
[TestMethod]
public void ByteReader_ReadToEnd()
{
var reader = new ByteReader(85);
Assert.IsTrue(reader.Read());
Assert.IsFalse(reader.Read());
Assert.IsTrue(reader.Read());
Assert.IsFalse(reader.Read());
Assert.IsTrue(reader.Read());
Assert.IsFalse(reader.Read());
Assert.IsTrue(reader.Read());
Assert.IsFalse(reader.Read());
}

[TestMethod]
public void ByteReader_ReadPartial()
{
var reader = new ByteReader(15);
Assert.AreEqual(3, reader.Read(2));
}
}
}

+ 0
- 51
Tests/MQTTnet.Core.Tests/ByteWriterTests.cs View File

@@ -1,51 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Serializer;

namespace MQTTnet.Core.Tests
{
[TestClass]
public class ByteWriterTests
{
[TestMethod]
public void ByteWriter_WriteMultipleAll()
{
var b = new ByteWriter();
Assert.AreEqual(0, b.Value);
b.Write(3, 2);
Assert.AreEqual(3, b.Value);
}

[TestMethod]
public void ByteWriter_WriteMultiplePartial()
{
var b = new ByteWriter();
Assert.AreEqual(0, b.Value);
b.Write(255, 2);
Assert.AreEqual(3, b.Value);
}

[TestMethod]
public void ByteWriter_WriteTo0xFF()
{
var b = new ByteWriter();

Assert.AreEqual(0, b.Value);
b.Write(true);
Assert.AreEqual(1, b.Value);
b.Write(true);
Assert.AreEqual(3, b.Value);
b.Write(true);
Assert.AreEqual(7, b.Value);
b.Write(true);
Assert.AreEqual(15, b.Value);
b.Write(true);
Assert.AreEqual(31, b.Value);
b.Write(true);
Assert.AreEqual(63, b.Value);
b.Write(true);
Assert.AreEqual(127, b.Value);
b.Write(true);
Assert.AreEqual(255, b.Value);
}
}
}

+ 6
- 6
Tests/MQTTnet.Core.Tests/ExtensionTests.cs View File

@@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests
[TestMethod]
public async Task TimeoutAfter()
{
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
}

[ExpectedException(typeof(MqttCommunicationTimedOutException))]
[TestMethod]
public async Task TimeoutAfterWithResult()
{
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None);
}

[TestMethod]
public async Task TimeoutAfterCompleteInTime()
{
var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None);
var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None);
Assert.AreEqual(5, result);
}

@@ -36,7 +36,7 @@ namespace MQTTnet.Core.Tests
{
try
{
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() =>
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() =>
{
var iis = new int[0];
iis[1] = 0;
@@ -55,7 +55,7 @@ namespace MQTTnet.Core.Tests
{
try
{
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() =>
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() =>
{
var iis = new int[0];
iis[1] = 0;
@@ -76,7 +76,7 @@ namespace MQTTnet.Core.Tests
var tasks = Enumerable.Range(0, 100000)
.Select(i =>
{
return MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None);
return MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None);
});

await Task.WhenAll(tasks);


+ 1
- 1
Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs View File

@@ -11,7 +11,7 @@ namespace MQTTnet.Core.Tests
public class MqttPacketReaderTests
{
[TestMethod]
[ExpectedException(typeof(MqttCommunicationException))]
[ExpectedException(typeof(MqttCommunicationClosedGracefullyException))]
public void MqttPacketReader_EmptyStream()
{
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult();


+ 1
- 2
Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs View File

@@ -1,6 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
@@ -36,7 +35,7 @@ namespace MQTTnet.Core.Tests
return Task.FromResult(0);
}

public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken)
public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken)
{
ThrowIfPartnerIsNull();



+ 37
- 61
Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs View File

@@ -12,16 +12,46 @@ namespace MQTTnet.TestApp.NetCore
{
public static class PerformanceTest
{
public static async Task RunAsync()
public static void Run()
{
Console.WriteLine("Press 'c' for concurrent sends. Otherwise in one batch.");
var concurrent = Console.ReadKey(true).KeyChar == 'c';
try
{
var mqttServer = new MqttFactory().CreateMqttServer();
mqttServer.StartAsync(new MqttServerOptions()).GetAwaiter().GetResult();

var server = Task.Run(RunServerAsync);
await Task.Delay(1000);
var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10), concurrent));
var options = new MqttClientOptions
{
ChannelOptions = new MqttClientTcpOptions
{
Server = "127.0.0.1"
},
CleanSession = true
};

var client = new MqttFactory().CreateMqttClient();
client.ConnectAsync(options).GetAwaiter().GetResult();

var message = CreateMessage();
var stopwatch = new Stopwatch();

await Task.WhenAll(server, client).ConfigureAwait(false);
for (var i = 0; i < 10; i++)
{
stopwatch.Restart();

var sentMessagesCount = 0;
while (stopwatch.ElapsedMilliseconds < 1000)
{
client.PublishAsync(message).GetAwaiter().GetResult();
sentMessagesCount++;
}

Console.WriteLine($"Sending {sentMessagesCount} messages per second. #" + (i + 1));
}
}
catch (Exception exception)
{
Console.WriteLine(exception);
}
}

private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval, bool concurrent)
@@ -53,29 +83,8 @@ namespace MQTTnet.TestApp.NetCore
}

var message = CreateMessage();
var messages = new[] { message };

var stopwatch = Stopwatch.StartNew();

var sentMessagesCount = 0;
while (stopwatch.ElapsedMilliseconds < 1000)
{
client.PublishAsync(messages).GetAwaiter().GetResult();
sentMessagesCount++;
}

Console.WriteLine($"Sending {sentMessagesCount} messages per second. #1");

sentMessagesCount = 0;
stopwatch.Restart();
while (stopwatch.ElapsedMilliseconds < 1000)
{
await client.PublishAsync(messages).ConfigureAwait(false);
sentMessagesCount++;
}

Console.WriteLine($"Sending {sentMessagesCount} messages per second. #2");

var testMessageCount = 10000;
for (var i = 0; i < testMessageCount; i++)
{
@@ -142,38 +151,5 @@ namespace MQTTnet.TestApp.NetCore
Interlocked.Increment(ref count);
return Task.Run(() => client.PublishAsync(applicationMessage));
}

private static async Task RunServerAsync()
{
try
{
var mqttServer = new MqttFactory().CreateMqttServer();

////var msgs = 0;
////var stopwatch = Stopwatch.StartNew();
////mqttServer.ApplicationMessageReceived += (sender, args) =>
////{
//// msgs++;
//// if (stopwatch.ElapsedMilliseconds > 1000)
//// {
//// Console.WriteLine($"received {msgs}");
//// msgs = 0;
//// stopwatch.Restart();
//// }
////};
await mqttServer.StartAsync(new MqttServerOptions());

Console.WriteLine("Press any key to exit.");
Console.ReadLine();

await mqttServer.StopAsync().ConfigureAwait(false);
}
catch (Exception e)
{
Console.WriteLine(e);
}

Console.ReadLine();
}
}
}

+ 2
- 1
Tests/MQTTnet.TestApp.NetCore/Program.cs View File

@@ -34,7 +34,8 @@ namespace MQTTnet.TestApp.NetCore
}
else if (pressedKey.KeyChar == '3')
{
Task.Run(PerformanceTest.RunAsync);
PerformanceTest.Run();
return;
}
else if (pressedKey.KeyChar == '4')
{


Loading…
Cancel
Save