@@ -48,6 +48,8 @@ | |||
<!-- .NET Framework --> | |||
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\net452\MQTTnet.Extensions.ManagedClient.*" target="lib\net452\"/> | |||
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\net461\MQTTnet.Extensions.ManagedClient.*" target="lib\net461\"/> | |||
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\netstandard2.0\MQTTnet.Extensions.ManagedClient.*" target="lib\net461\"/> | |||
<file src="..\Source\MQTTnet.Extensions.ManagedClient\bin\Release\netstandard2.0\MQTTnet.Extensions.ManagedClient.*" target="lib\net472\"/> | |||
</files> | |||
</package> |
@@ -48,6 +48,8 @@ | |||
<!-- .NET Framework --> | |||
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\net452\MQTTnet.Extensions.Rpc.*" target="lib\net452\"/> | |||
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\net461\MQTTnet.Extensions.Rpc.*" target="lib\net461\"/> | |||
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\netstandard2.0\MQTTnet.Extensions.Rpc.*" target="lib\net461\"/> | |||
<file src="..\Source\MQTTnet.Extensions.Rpc\bin\Release\netstandard2.0\MQTTnet.Extensions.Rpc.*" target="lib\net462\"/> | |||
</files> | |||
</package> |
@@ -69,6 +69,8 @@ | |||
<!-- .NET Framework --> | |||
<file src="..\Source\MQTTnet\bin\Release\net452\MQTTnet.*" target="lib\net452\"/> | |||
<file src="..\Source\MQTTnet\bin\Release\net461\MQTTnet.*" target="lib\net461\"/> | |||
<file src="..\Source\MQTTnet\bin\Release\netstandard2.0\MQTTnet.*" target="lib\net461\"/> | |||
<file src="..\Source\MQTTnet\bin\Release\netstandard2.0\MQTTnet.*" target="lib\net472\"/> | |||
</files> | |||
</package> |
@@ -20,7 +20,7 @@ namespace MQTTnet.Adapter | |||
Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken); | |||
Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken); | |||
Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken); | |||
Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); | |||
} | |||
@@ -40,52 +40,84 @@ namespace MQTTnet.Adapter | |||
public event EventHandler ReadingPacketStarted; | |||
public event EventHandler ReadingPacketCompleted; | |||
public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
public async Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
ThrowIfDisposed(); | |||
_logger.Verbose("Connecting [Timeout={0}]", timeout); | |||
return ExecuteAndWrapExceptionAsync(() => | |||
Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken)); | |||
try | |||
{ | |||
_logger.Verbose("Connecting [Timeout={0}]", timeout); | |||
await Internal.TaskExtensions | |||
.TimeoutAfterAsync(ct => _channel.ConnectAsync(ct), timeout, cancellationToken) | |||
.ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
if (IsWrappedException(exception)) | |||
{ | |||
throw; | |||
} | |||
WrapException(exception); | |||
} | |||
} | |||
public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
public async Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
ThrowIfDisposed(); | |||
_logger.Verbose("Disconnecting [Timeout={0}]", timeout); | |||
return ExecuteAndWrapExceptionAsync(() => | |||
Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, cancellationToken)); | |||
try | |||
{ | |||
_logger.Verbose("Disconnecting [Timeout={0}]", timeout); | |||
await Internal.TaskExtensions | |||
.TimeoutAfterAsync(ct => _channel.DisconnectAsync(), timeout, cancellationToken) | |||
.ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
if (IsWrappedException(exception)) | |||
{ | |||
throw; | |||
} | |||
WrapException(exception); | |||
} | |||
} | |||
public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken) | |||
public async Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) | |||
{ | |||
return ExecuteAndWrapExceptionAsync(() => | |||
try | |||
{ | |||
_logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); | |||
_logger.Verbose("TX >>> {0}", packet); | |||
var packetData = PacketSerializer.Serialize(packet); | |||
return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync( | |||
packetData.Array, | |||
packetData.Offset, | |||
packetData.Count, | |||
ct), timeout, cancellationToken); | |||
}); | |||
await _channel.WriteAsync(packetData.Array, packetData.Offset, packetData.Count, cancellationToken).ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
if (IsWrappedException(exception)) | |||
{ | |||
throw; | |||
} | |||
WrapException(exception); | |||
} | |||
} | |||
public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
ThrowIfDisposed(); | |||
MqttBasePacket packet = null; | |||
await ExecuteAndWrapExceptionAsync(async () => | |||
try | |||
{ | |||
ReceivedMqttPacket receivedMqttPacket; | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); | |||
receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfterAsync(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
@@ -94,19 +126,30 @@ namespace MQTTnet.Adapter | |||
if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) | |||
{ | |||
return; | |||
return null; | |||
} | |||
packet = PacketSerializer.Deserialize(receivedMqttPacket); | |||
var packet = PacketSerializer.Deserialize(receivedMqttPacket); | |||
if (packet == null) | |||
{ | |||
throw new MqttProtocolViolationException("Received malformed packet."); | |||
} | |||
_logger.Verbose("RX <<< {0}", packet); | |||
}).ConfigureAwait(false); | |||
return packet; | |||
} | |||
catch (Exception exception) | |||
{ | |||
if (IsWrappedException(exception)) | |||
{ | |||
throw; | |||
} | |||
WrapException(exception); | |||
} | |||
return packet; | |||
return null; | |||
} | |||
private async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
@@ -138,7 +181,9 @@ namespace MQTTnet.Adapter | |||
chunkSize = bytesLeft; | |||
} | |||
var readBytes = await channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken) .ConfigureAwait(false); | |||
// async/await is not used to avoid the overhead of context switches. We assume that the reamining data | |||
// has been sent from the sender directly after the initial bytes. | |||
var readBytes = channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken).GetAwaiter().GetResult(); | |||
if (readBytes <= 0) | |||
{ | |||
ExceptionHelper.ThrowGracefulSocketClose(); | |||
@@ -155,42 +200,6 @@ namespace MQTTnet.Adapter | |||
} | |||
} | |||
private static async Task ExecuteAndWrapExceptionAsync(Func<Task> action) | |||
{ | |||
try | |||
{ | |||
await action().ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
if (exception is TaskCanceledException || | |||
exception is OperationCanceledException || | |||
exception is MqttCommunicationTimedOutException || | |||
exception is MqttCommunicationException) | |||
{ | |||
throw; | |||
} | |||
if (exception is IOException && exception.InnerException is SocketException socketException) | |||
{ | |||
if (socketException.SocketErrorCode == SocketError.ConnectionAborted) | |||
{ | |||
throw new OperationCanceledException(); | |||
} | |||
} | |||
if (exception is COMException comException) | |||
{ | |||
if ((uint)comException.HResult == ErrorOperationAborted) | |||
{ | |||
throw new OperationCanceledException(); | |||
} | |||
} | |||
throw new MqttCommunicationException(exception); | |||
} | |||
} | |||
public void Dispose() | |||
{ | |||
_isDisposed = true; | |||
@@ -205,5 +214,34 @@ namespace MQTTnet.Adapter | |||
throw new ObjectDisposedException(nameof(MqttChannelAdapter)); | |||
} | |||
} | |||
private static bool IsWrappedException(Exception exception) | |||
{ | |||
return exception is TaskCanceledException || | |||
exception is OperationCanceledException || | |||
exception is MqttCommunicationTimedOutException || | |||
exception is MqttCommunicationException; | |||
} | |||
private static void WrapException(Exception exception) | |||
{ | |||
if (exception is IOException && exception.InnerException is SocketException socketException) | |||
{ | |||
if (socketException.SocketErrorCode == SocketError.ConnectionAborted) | |||
{ | |||
throw new OperationCanceledException(); | |||
} | |||
} | |||
if (exception is COMException comException) | |||
{ | |||
if ((uint)comException.HResult == ErrorOperationAborted) | |||
{ | |||
throw new OperationCanceledException(); | |||
} | |||
} | |||
throw new MqttCommunicationException(exception); | |||
} | |||
} | |||
} |
@@ -17,7 +17,7 @@ namespace MQTTnet.Client | |||
{ | |||
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); | |||
private readonly Stopwatch _sendTracker = new Stopwatch(); | |||
private readonly SemaphoreSlim _disconnectLock = new SemaphoreSlim(1, 1); | |||
private readonly object _disconnectLock = new object(); | |||
private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); | |||
private readonly IMqttClientAdapterFactory _adapterFactory; | |||
@@ -215,7 +215,7 @@ namespace MQTTnet.Client | |||
private async Task DisconnectInternalAsync(Task sender, Exception exception) | |||
{ | |||
await InitiateDisconnectAsync().ConfigureAwait(false); | |||
InitiateDisconnect(); | |||
var clientWasConnected = IsConnected; | |||
IsConnected = false; | |||
@@ -249,25 +249,23 @@ namespace MQTTnet.Client | |||
} | |||
} | |||
private async Task InitiateDisconnectAsync() | |||
private void InitiateDisconnect() | |||
{ | |||
await _disconnectLock.WaitAsync().ConfigureAwait(false); | |||
try | |||
lock (_disconnectLock) | |||
{ | |||
if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested) | |||
try | |||
{ | |||
return; | |||
} | |||
if (_cancellationTokenSource == null || _cancellationTokenSource.IsCancellationRequested) | |||
{ | |||
return; | |||
} | |||
_cancellationTokenSource.Cancel(false); | |||
} | |||
catch (Exception adapterException) | |||
{ | |||
_logger.Warning(adapterException, "Error while initiating disconnect."); | |||
} | |||
finally | |||
{ | |||
_disconnectLock.Release(); | |||
_cancellationTokenSource.Cancel(false); | |||
} | |||
catch (Exception adapterException) | |||
{ | |||
_logger.Warning(adapterException, "Error while initiating disconnect."); | |||
} | |||
} | |||
} | |||
@@ -279,7 +277,7 @@ namespace MQTTnet.Client | |||
} | |||
_sendTracker.Restart(); | |||
return _adapter.SendPacketAsync(_options.CommunicationTimeout, packet, cancellationToken); | |||
return _adapter.SendPacketAsync(packet, cancellationToken); | |||
} | |||
private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket | |||
@@ -300,8 +298,8 @@ namespace MQTTnet.Client | |||
var packetAwaiter = _packetDispatcher.AddPacketAwaiter<TResponsePacket>(identifier); | |||
try | |||
{ | |||
await _adapter.SendPacketAsync(_options.CommunicationTimeout, requestPacket, cancellationToken).ConfigureAwait(false); | |||
var respone = await Internal.TaskExtensions.TimeoutAfter(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); | |||
await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); | |||
var respone = await Internal.TaskExtensions.TimeoutAfterAsync(ct => packetAwaiter.Task, _options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); | |||
return (TResponsePacket)respone; | |||
} | |||
@@ -10,6 +10,8 @@ | |||
return "net452"; | |||
#elif NET461 | |||
return "net461"; | |||
#elif NET472 | |||
return "net472"; | |||
#elif NETSTANDARD1_3 | |||
return "netstandard1.3"; | |||
#elif NETSTANDARD2_0 | |||
@@ -1,4 +1,4 @@ | |||
#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 | |||
#if !WINDOWS_UWP | |||
using System; | |||
using System.Net.Security; | |||
using System.Net.Sockets; | |||
@@ -1,4 +1,4 @@ | |||
#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 | |||
#if !WINDOWS_UWP | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Net.Sockets; | |||
@@ -1,4 +1,4 @@ | |||
#if NET452 || NET461 || NETSTANDARD1_3 || NETSTANDARD2_0 | |||
#if !WINDOWS_UWP | |||
using System; | |||
using System.Net; | |||
using System.Net.Security; | |||
@@ -76,7 +76,8 @@ namespace MQTTnet.Implementations | |||
await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false); | |||
} | |||
_logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {_addressFamily}'."); | |||
var protocol = _addressFamily == AddressFamily.InterNetwork ? "ipv4" : "ipv6"; | |||
_logger.Verbose($"Client '{clientSocket.RemoteEndPoint}' accepted by TCP listener '{_socket.LocalEndPoint}, {protocol}'."); | |||
var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); | |||
ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); | |||
@@ -102,7 +103,7 @@ namespace MQTTnet.Implementations | |||
{ | |||
_socket?.Dispose(); | |||
#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 | |||
#if NETSTANDARD1_3 || NETSTANDARD2_0 || NET461 || NET472 | |||
_tlsCertificate?.Dispose(); | |||
#endif | |||
} | |||
@@ -7,7 +7,7 @@ namespace MQTTnet.Internal | |||
{ | |||
public static class TaskExtensions | |||
{ | |||
public static async Task TimeoutAfter(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken) | |||
public static async Task TimeoutAfterAsync(Func<CancellationToken, Task> action, TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
if (action == null) throw new ArgumentNullException(nameof(action)); | |||
@@ -31,7 +31,7 @@ namespace MQTTnet.Internal | |||
} | |||
} | |||
public static async Task<TResult> TimeoutAfter<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken) | |||
public static async Task<TResult> TimeoutAfterAsync<TResult>(Func<CancellationToken, Task<TResult>> action, TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
if (action == null) throw new ArgumentNullException(nameof(action)); | |||
@@ -1,7 +1,7 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFrameworks Condition=" '$(OS)' == 'Windows_NT' ">netstandard1.3;netstandard2.0;net452;net461;uap10.0</TargetFrameworks> | |||
<TargetFrameworks Condition=" '$(OS)' == 'Windows_NT' ">netstandard1.3;netstandard2.0;net452;uap10.0</TargetFrameworks> | |||
<TargetFrameworks Condition=" '$(OS)' != 'Windows_NT' ">netstandard1.3;netstandard2.0</TargetFrameworks> | |||
<AssemblyName>MQTTnet</AssemblyName> | |||
<RootNamespace>MQTTnet</RootNamespace> | |||
@@ -62,7 +62,4 @@ | |||
<ItemGroup Condition="'$(TargetFramework)'=='net452'"> | |||
</ItemGroup> | |||
<ItemGroup Condition="'$(TargetFramework)'=='net461'"> | |||
</ItemGroup> | |||
</Project> |
@@ -1,48 +0,0 @@ | |||
using System; | |||
namespace MQTTnet.Serializer | |||
{ | |||
public class ByteReader | |||
{ | |||
private readonly int _source; | |||
private int _index; | |||
public ByteReader(int source) | |||
{ | |||
_source = source; | |||
} | |||
public bool Read() | |||
{ | |||
if (_index >= 8) | |||
{ | |||
throw new InvalidOperationException("End of byte reached."); | |||
} | |||
var result = ((1 << _index) & _source) > 0; | |||
_index++; | |||
return result; | |||
} | |||
public int Read(int count) | |||
{ | |||
if (_index + count > 8) | |||
{ | |||
throw new InvalidOperationException("End of byte will be reached."); | |||
} | |||
var result = 0; | |||
for (var i = 0; i < count; i++) | |||
{ | |||
if (((1 << _index) & _source) > 0) | |||
{ | |||
result |= 1 << i; | |||
} | |||
_index++; | |||
} | |||
return result; | |||
} | |||
} | |||
} |
@@ -1,36 +0,0 @@ | |||
using System; | |||
namespace MQTTnet.Serializer | |||
{ | |||
public class ByteWriter | |||
{ | |||
private int _index; | |||
private int _byte; | |||
public byte Value => (byte)_byte; | |||
public void Write(int @byte, int count) | |||
{ | |||
for (var i = 0; i < count; i++) | |||
{ | |||
var value = ((1 << i) & @byte) > 0; | |||
Write(value); | |||
} | |||
} | |||
public void Write(bool bit) | |||
{ | |||
if (_index >= 8) | |||
{ | |||
throw new InvalidOperationException("End of the byte reached."); | |||
} | |||
if (bit) | |||
{ | |||
_byte |= 1 << _index; | |||
} | |||
_index++; | |||
} | |||
} | |||
} |
@@ -12,6 +12,8 @@ namespace MQTTnet.Serializer | |||
{ | |||
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | |||
// So in all cases at least 2 bytes must be read for a complete MQTT packet. | |||
// async/await is used here because the next packet is received in a couple of minutes so the performance | |||
// impact is acceptable according to a useless waiting thread. | |||
var buffer = new byte[2]; | |||
var totalBytesRead = 0; | |||
@@ -37,11 +39,11 @@ namespace MQTTnet.Serializer | |||
return new MqttFixedHeader(buffer[0], 0); | |||
} | |||
var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken).ConfigureAwait(false); | |||
var bodyLength = ReadBodyLength(channel, buffer[1], cancellationToken); | |||
return new MqttFixedHeader(buffer[0], bodyLength); | |||
} | |||
private static async Task<int> ReadBodyLengthAsync(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) | |||
private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) | |||
{ | |||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | |||
var multiplier = 128; | |||
@@ -50,7 +52,11 @@ namespace MQTTnet.Serializer | |||
while ((encodedByte & 128) != 0) | |||
{ | |||
encodedByte = await ReadByteAsync(channel, cancellationToken).ConfigureAwait(false); | |||
// Here the async/await pattern is not used becuase the overhead of context switches | |||
// is too big for reading 1 byte in a row. We expect that the remaining data was sent | |||
// directly after the initial bytes. If the client disconnects just in this moment we | |||
// will get an exception anyway. | |||
encodedByte = ReadByteAsync(channel, cancellationToken).GetAwaiter().GetResult(); | |||
value += (byte)(encodedByte & 127) * multiplier; | |||
if (multiplier > 128 * 128 * 128) | |||
@@ -2,7 +2,6 @@ | |||
using MQTTnet.Packets; | |||
using MQTTnet.Protocol; | |||
using System; | |||
using System.IO; | |||
using System.Linq; | |||
using MQTTnet.Adapter; | |||
@@ -18,57 +17,46 @@ namespace MQTTnet.Serializer | |||
{ | |||
if (packet == null) throw new ArgumentNullException(nameof(packet)); | |||
using (var stream = new MemoryStream(128)) | |||
{ | |||
// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) | |||
stream.Seek(5, SeekOrigin.Begin); | |||
var packetWriter = new MqttPacketWriter(); | |||
var fixedHeader = SerializePacket(packet, stream); | |||
var remainingLength = (int)stream.Length - 5; | |||
// Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) | |||
packetWriter.Seek(5); | |||
var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); | |||
var fixedHeader = SerializePacket(packet, packetWriter); | |||
var remainingLength = packetWriter.Length - 5; | |||
var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; | |||
var headerOffset = 5 - headerSize; | |||
var remainingLengthBuffer = MqttPacketWriter.EncodeRemainingLength(remainingLength); | |||
// Position cursor on correct offset on beginining of array (has leading 0x0) | |||
stream.Seek(headerOffset, SeekOrigin.Begin); | |||
stream.WriteByte(fixedHeader); | |||
stream.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); | |||
var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; | |||
var headerOffset = 5 - headerSize; | |||
#if NET461 || NET452 || NETSTANDARD2_0 | |||
var buffer = stream.GetBuffer(); | |||
return new ArraySegment<byte>(buffer, headerOffset, (int)stream.Length - headerOffset); | |||
#else | |||
if (stream.TryGetBuffer(out var segment)) | |||
{ | |||
return new ArraySegment<byte>(segment.Array, headerOffset, segment.Count - headerOffset); | |||
} | |||
// Position cursor on correct offset on beginining of array (has leading 0x0) | |||
packetWriter.Seek(headerOffset); | |||
packetWriter.Write(fixedHeader); | |||
packetWriter.Write(remainingLengthBuffer.Array, remainingLengthBuffer.Offset, remainingLengthBuffer.Count); | |||
var buffer = stream.ToArray(); | |||
return new ArraySegment<byte>(buffer, headerOffset, buffer.Length - headerOffset); | |||
#endif | |||
} | |||
var buffer = packetWriter.GetBuffer(); | |||
return new ArraySegment<byte>(buffer, headerOffset, packetWriter.Length - headerOffset); | |||
} | |||
private byte SerializePacket(MqttBasePacket packet, Stream stream) | |||
private byte SerializePacket(MqttBasePacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
switch (packet) | |||
{ | |||
case MqttConnectPacket connectPacket: return Serialize(connectPacket, stream); | |||
case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, stream); | |||
case MqttConnectPacket connectPacket: return Serialize(connectPacket, packetWriter); | |||
case MqttConnAckPacket connAckPacket: return Serialize(connAckPacket, packetWriter); | |||
case MqttDisconnectPacket _: return SerializeEmptyPacket(MqttControlPacketType.Disconnect); | |||
case MqttPingReqPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingReq); | |||
case MqttPingRespPacket _: return SerializeEmptyPacket(MqttControlPacketType.PingResp); | |||
case MqttPublishPacket publishPacket: return Serialize(publishPacket, stream); | |||
case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, stream); | |||
case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, stream); | |||
case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, stream); | |||
case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, stream); | |||
case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, stream); | |||
case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, stream); | |||
case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, stream); | |||
case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, stream); | |||
case MqttPublishPacket publishPacket: return Serialize(publishPacket, packetWriter); | |||
case MqttPubAckPacket pubAckPacket: return Serialize(pubAckPacket, packetWriter); | |||
case MqttPubRecPacket pubRecPacket: return Serialize(pubRecPacket, packetWriter); | |||
case MqttPubRelPacket pubRelPacket: return Serialize(pubRelPacket, packetWriter); | |||
case MqttPubCompPacket pubCompPacket: return Serialize(pubCompPacket, packetWriter); | |||
case MqttSubscribePacket subscribePacket: return Serialize(subscribePacket, packetWriter); | |||
case MqttSubAckPacket subAckPacket: return Serialize(subAckPacket, packetWriter); | |||
case MqttUnsubscribePacket unsubscribePacket: return Serialize(unsubscribePacket, packetWriter); | |||
case MqttUnsubAckPacket unsubAckPacket: return Serialize(unsubAckPacket, packetWriter); | |||
default: throw new MqttProtocolViolationException("Packet type invalid."); | |||
} | |||
} | |||
@@ -195,10 +183,9 @@ namespace MQTTnet.Serializer | |||
var body = receivedMqttPacket.Body; | |||
ThrowIfBodyIsEmpty(body); | |||
var fixedHeader = new ByteReader(receivedMqttPacket.FixedHeader); | |||
var retain = fixedHeader.Read(); | |||
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)fixedHeader.Read(2); | |||
var dup = fixedHeader.Read(); | |||
var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; | |||
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); | |||
var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0; | |||
var topic = body.ReadStringWithLengthPrefix(); | |||
@@ -253,8 +240,8 @@ namespace MQTTnet.Serializer | |||
throw new MqttProtocolViolationException($"Protocol name ({protocolName}) is not supported."); | |||
} | |||
var connectFlags = new ByteReader(body.ReadByte()); | |||
if (connectFlags.Read()) | |||
var connectFlags = body.ReadByte(); | |||
if ((connectFlags & 0x1) > 0) | |||
{ | |||
throw new MqttProtocolViolationException("The first bit of the Connect Flags must be set to 0."); | |||
} | |||
@@ -262,14 +249,14 @@ namespace MQTTnet.Serializer | |||
var packet = new MqttConnectPacket | |||
{ | |||
ProtocolVersion = protocolVersion, | |||
CleanSession = connectFlags.Read() | |||
CleanSession = (connectFlags & 0x2) > 0 | |||
}; | |||
var willFlag = connectFlags.Read(); | |||
var willQoS = connectFlags.Read(2); | |||
var willRetain = connectFlags.Read(); | |||
var passwordFlag = connectFlags.Read(); | |||
var usernameFlag = connectFlags.Read(); | |||
var willFlag = (connectFlags & 0x4) > 0; | |||
var willQoS = (connectFlags & 0x18) >> 3; | |||
var willRetain = (connectFlags & 0x20) > 0; | |||
var passwordFlag = (connectFlags & 0x40) > 0; | |||
var usernameFlag = (connectFlags & 0x80) > 0; | |||
packet.KeepAlivePeriod = body.ReadUInt16(); | |||
packet.ClientId = body.ReadStringWithLengthPrefix(); | |||
@@ -322,11 +309,11 @@ namespace MQTTnet.Serializer | |||
var packet = new MqttConnAckPacket(); | |||
var firstByteReader = new ByteReader(body.ReadByte()); | |||
var acknowledgeFlags = body.ReadByte(); | |||
if (ProtocolVersion == MqttProtocolVersion.V311) | |||
{ | |||
packet.IsSessionPresent = firstByteReader.Read(); | |||
packet.IsSessionPresent = (acknowledgeFlags & 0x1) > 0; | |||
} | |||
packet.ConnectReturnCode = (MqttConnectReturnCode)body.ReadByte(); | |||
@@ -344,119 +331,129 @@ namespace MQTTnet.Serializer | |||
} | |||
} | |||
// ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local | |||
private static void ValidatePublishPacket(MqttPublishPacket packet) | |||
{ | |||
if (packet == null) throw new ArgumentNullException(nameof(packet)); | |||
if (packet.QualityOfServiceLevel == 0 && packet.Dup) | |||
{ | |||
throw new MqttProtocolViolationException("Dup flag must be false for QoS 0 packets [MQTT-3.3.1-2]."); | |||
} | |||
} | |||
private byte Serialize(MqttConnectPacket packet, Stream stream) | |||
private byte Serialize(MqttConnectPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
ValidateConnectPacket(packet); | |||
// Write variable header | |||
if (ProtocolVersion == MqttProtocolVersion.V311) | |||
{ | |||
stream.WriteWithLengthPrefix("MQTT"); | |||
stream.WriteByte(4); // 3.1.2.2 Protocol Level 4 | |||
packetWriter.WriteWithLengthPrefix("MQTT"); | |||
packetWriter.Write(4); // 3.1.2.2 Protocol Level 4 | |||
} | |||
else | |||
{ | |||
stream.WriteWithLengthPrefix("MQIsdp"); | |||
stream.WriteByte(3); // Protocol Level 3 | |||
packetWriter.WriteWithLengthPrefix("MQIsdp"); | |||
packetWriter.Write(3); // Protocol Level 3 | |||
} | |||
var connectFlags = new ByteWriter(); // 3.1.2.3 Connect Flags | |||
connectFlags.Write(false); // Reserved | |||
connectFlags.Write(packet.CleanSession); | |||
connectFlags.Write(packet.WillMessage != null); | |||
if (packet.WillMessage != null) | |||
byte connectFlags = 0x0; | |||
if (packet.CleanSession) | |||
{ | |||
connectFlags.Write((int)packet.WillMessage.QualityOfServiceLevel, 2); | |||
connectFlags.Write(packet.WillMessage.Retain); | |||
connectFlags |= 0x2; | |||
} | |||
else | |||
if (packet.WillMessage != null) | |||
{ | |||
connectFlags.Write(0, 2); | |||
connectFlags.Write(false); | |||
} | |||
connectFlags |= 0x4; | |||
connectFlags |= (byte)((byte)packet.WillMessage.QualityOfServiceLevel << 3); | |||
if (packet.WillMessage.Retain) | |||
{ | |||
connectFlags |= 0x20; | |||
} | |||
} | |||
if (packet.Password != null && packet.Username == null) | |||
{ | |||
throw new MqttProtocolViolationException("If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22]."); | |||
} | |||
connectFlags.Write(packet.Password != null); | |||
connectFlags.Write(packet.Username != null); | |||
if (packet.Password != null) | |||
{ | |||
connectFlags |= 0x40; | |||
} | |||
stream.Write(connectFlags); | |||
stream.Write(packet.KeepAlivePeriod); | |||
stream.WriteWithLengthPrefix(packet.ClientId); | |||
if (packet.Username != null) | |||
{ | |||
connectFlags |= 0x80; | |||
} | |||
packetWriter.Write(connectFlags); | |||
packetWriter.Write(packet.KeepAlivePeriod); | |||
packetWriter.WriteWithLengthPrefix(packet.ClientId); | |||
if (packet.WillMessage != null) | |||
{ | |||
stream.WriteWithLengthPrefix(packet.WillMessage.Topic); | |||
stream.WriteWithLengthPrefix(packet.WillMessage.Payload); | |||
packetWriter.WriteWithLengthPrefix(packet.WillMessage.Topic); | |||
packetWriter.WriteWithLengthPrefix(packet.WillMessage.Payload); | |||
} | |||
if (packet.Username != null) | |||
{ | |||
stream.WriteWithLengthPrefix(packet.Username); | |||
packetWriter.WriteWithLengthPrefix(packet.Username); | |||
} | |||
if (packet.Password != null) | |||
{ | |||
stream.WriteWithLengthPrefix(packet.Password); | |||
packetWriter.WriteWithLengthPrefix(packet.Password); | |||
} | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Connect); | |||
} | |||
private byte Serialize(MqttConnAckPacket packet, Stream stream) | |||
private byte Serialize(MqttConnAckPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (ProtocolVersion == MqttProtocolVersion.V310) | |||
{ | |||
stream.WriteByte(0); | |||
packetWriter.Write(0); | |||
} | |||
else if (ProtocolVersion == MqttProtocolVersion.V311) | |||
{ | |||
var connectAcknowledgeFlags = new ByteWriter(); | |||
connectAcknowledgeFlags.Write(packet.IsSessionPresent); | |||
stream.Write(connectAcknowledgeFlags); | |||
byte connectAcknowledgeFlags = 0x0; | |||
if (packet.IsSessionPresent) | |||
{ | |||
connectAcknowledgeFlags |= 0x1; | |||
} | |||
packetWriter.Write(connectAcknowledgeFlags); | |||
} | |||
else | |||
{ | |||
throw new MqttProtocolViolationException("Protocol version not supported."); | |||
} | |||
stream.WriteByte((byte)packet.ConnectReturnCode); | |||
packetWriter.Write((byte)packet.ConnectReturnCode); | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.ConnAck); | |||
} | |||
private static byte Serialize(MqttPubRelPacket packet, Stream stream) | |||
private static byte Serialize(MqttPubRelPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.PacketIdentifier.HasValue) | |||
{ | |||
throw new MqttProtocolViolationException("PubRel packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRel, 0x02); | |||
} | |||
private static byte Serialize(MqttPublishPacket packet, Stream stream) | |||
private static byte Serialize(MqttPublishPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
ValidatePublishPacket(packet); | |||
stream.WriteWithLengthPrefix(packet.Topic); | |||
packetWriter.WriteWithLengthPrefix(packet.Topic); | |||
if (packet.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) | |||
{ | |||
@@ -465,7 +462,7 @@ namespace MQTTnet.Serializer | |||
throw new MqttProtocolViolationException("Publish packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
} | |||
else | |||
{ | |||
@@ -477,7 +474,7 @@ namespace MQTTnet.Serializer | |||
if (packet.Payload?.Length > 0) | |||
{ | |||
stream.Write(packet.Payload, 0, packet.Payload.Length); | |||
packetWriter.Write(packet.Payload, 0, packet.Payload.Length); | |||
} | |||
byte fixedHeader = 0; | |||
@@ -497,43 +494,43 @@ namespace MQTTnet.Serializer | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); | |||
} | |||
private static byte Serialize(MqttPubAckPacket packet, Stream stream) | |||
private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.PacketIdentifier.HasValue) | |||
{ | |||
throw new MqttProtocolViolationException("PubAck packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubAck); | |||
} | |||
private static byte Serialize(MqttPubRecPacket packet, Stream stream) | |||
private static byte Serialize(MqttPubRecPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.PacketIdentifier.HasValue) | |||
{ | |||
throw new MqttProtocolViolationException("PubRec packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubRec); | |||
} | |||
private static byte Serialize(MqttPubCompPacket packet, Stream stream) | |||
private static byte Serialize(MqttPubCompPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.PacketIdentifier.HasValue) | |||
{ | |||
throw new MqttProtocolViolationException("PubComp packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.PubComp); | |||
} | |||
private static byte Serialize(MqttSubscribePacket packet, Stream stream) | |||
private static byte Serialize(MqttSubscribePacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3]."); | |||
@@ -542,41 +539,41 @@ namespace MQTTnet.Serializer | |||
throw new MqttProtocolViolationException("Subscribe packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
if (packet.TopicFilters?.Count > 0) | |||
{ | |||
foreach (var topicFilter in packet.TopicFilters) | |||
{ | |||
stream.WriteWithLengthPrefix(topicFilter.Topic); | |||
stream.WriteByte((byte)topicFilter.QualityOfServiceLevel); | |||
packetWriter.WriteWithLengthPrefix(topicFilter.Topic); | |||
packetWriter.Write((byte)topicFilter.QualityOfServiceLevel); | |||
} | |||
} | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Subscribe, 0x02); | |||
} | |||
private static byte Serialize(MqttSubAckPacket packet, Stream stream) | |||
private static byte Serialize(MqttSubAckPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.PacketIdentifier.HasValue) | |||
{ | |||
throw new MqttProtocolViolationException("SubAck packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
if (packet.SubscribeReturnCodes?.Any() == true) | |||
{ | |||
foreach (var packetSubscribeReturnCode in packet.SubscribeReturnCodes) | |||
{ | |||
stream.WriteByte((byte)packetSubscribeReturnCode); | |||
packetWriter.Write((byte)packetSubscribeReturnCode); | |||
} | |||
} | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.SubAck); | |||
} | |||
private static byte Serialize(MqttUnsubscribePacket packet, Stream stream) | |||
private static byte Serialize(MqttUnsubscribePacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.TopicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2]."); | |||
@@ -585,27 +582,27 @@ namespace MQTTnet.Serializer | |||
throw new MqttProtocolViolationException("Unsubscribe packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
if (packet.TopicFilters?.Any() == true) | |||
{ | |||
foreach (var topicFilter in packet.TopicFilters) | |||
{ | |||
stream.WriteWithLengthPrefix(topicFilter); | |||
packetWriter.WriteWithLengthPrefix(topicFilter); | |||
} | |||
} | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Unsubscibe, 0x02); | |||
} | |||
private static byte Serialize(MqttUnsubAckPacket packet, Stream stream) | |||
private static byte Serialize(MqttUnsubAckPacket packet, MqttPacketWriter packetWriter) | |||
{ | |||
if (!packet.PacketIdentifier.HasValue) | |||
{ | |||
throw new MqttProtocolViolationException("UnsubAck packet has no packet identifier."); | |||
} | |||
stream.Write(packet.PacketIdentifier.Value); | |||
packetWriter.Write(packet.PacketIdentifier.Value); | |||
return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.UnsubAck); | |||
} | |||
@@ -614,6 +611,7 @@ namespace MQTTnet.Serializer | |||
return MqttPacketWriter.BuildFixedHeader(type); | |||
} | |||
// ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local | |||
private static void ThrowIfBodyIsEmpty(MqttPacketBodyReader body) | |||
{ | |||
if (body == null || body.Length == 0) | |||
@@ -1,12 +1,23 @@ | |||
using System; | |||
using System.IO; | |||
using System.Text; | |||
using MQTTnet.Protocol; | |||
namespace MQTTnet.Serializer | |||
{ | |||
public static class MqttPacketWriter | |||
/// <summary> | |||
/// This is a custom implementation of a memory stream which provides only MQTTnet relevant features. | |||
/// The goal is to avoid lots of argument checks like in the original stream. The growth rule is the | |||
/// same as for the original MemoryStream in .net. Also this implementation allows accessing the internal | |||
/// buffer for all platforms and .net framework versions (which is not available at the regular MemoryStream). | |||
/// </summary> | |||
public class MqttPacketWriter | |||
{ | |||
private byte[] _buffer = new byte[128]; | |||
private int _position; | |||
public int Length { get; private set; } | |||
public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) | |||
{ | |||
var fixedHeader = (int)packetType << 4; | |||
@@ -14,33 +25,6 @@ namespace MQTTnet.Serializer | |||
return (byte)fixedHeader; | |||
} | |||
public static void Write(this Stream stream, ushort value) | |||
{ | |||
var buffer = BitConverter.GetBytes(value); | |||
stream.WriteByte(buffer[1]); | |||
stream.WriteByte(buffer[0]); | |||
} | |||
public static void Write(this Stream stream, ByteWriter value) | |||
{ | |||
if (value == null) throw new ArgumentNullException(nameof(value)); | |||
stream.WriteByte(value.Value); | |||
} | |||
public static void WriteWithLengthPrefix(this Stream stream, string value) | |||
{ | |||
stream.WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); | |||
} | |||
public static void WriteWithLengthPrefix(this Stream stream, byte[] value) | |||
{ | |||
var length = (ushort)value.Length; | |||
stream.Write(length); | |||
stream.Write(value, 0, length); | |||
} | |||
public static ArraySegment<byte> EncodeRemainingLength(int length) | |||
{ | |||
// write the encoded remaining length right aligned on the 4 byte buffer | |||
@@ -69,5 +53,91 @@ namespace MQTTnet.Serializer | |||
return new ArraySegment<byte>(buffer, 0, bufferOffset); | |||
} | |||
public void WriteWithLengthPrefix(string value) | |||
{ | |||
WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); | |||
} | |||
public void WriteWithLengthPrefix(byte[] value) | |||
{ | |||
EnsureAdditionalCapacity(value.Length + 2); | |||
Write((ushort)value.Length); | |||
Write(value, 0, value.Length); | |||
} | |||
public void Write(byte @byte) | |||
{ | |||
EnsureAdditionalCapacity(1); | |||
_buffer[_position] = @byte; | |||
IncreasePostition(1); | |||
} | |||
public void Write(ushort value) | |||
{ | |||
EnsureAdditionalCapacity(2); | |||
_buffer[_position] = (byte)(value >> 8); | |||
IncreasePostition(1); | |||
_buffer[_position] = (byte)value; | |||
IncreasePostition(1); | |||
} | |||
public void Write(byte[] array, int offset, int count) | |||
{ | |||
EnsureAdditionalCapacity(count); | |||
Array.Copy(array, offset, _buffer, _position, count); | |||
IncreasePostition(count); | |||
} | |||
public void Seek(int offset) | |||
{ | |||
EnsureCapacity(offset); | |||
_position = offset; | |||
} | |||
public byte[] GetBuffer() | |||
{ | |||
return _buffer; | |||
} | |||
private void EnsureAdditionalCapacity(int additionalCapacity) | |||
{ | |||
var freeSpace = _buffer.Length - _position; | |||
if (freeSpace >= additionalCapacity) | |||
{ | |||
return; | |||
} | |||
EnsureCapacity(additionalCapacity - freeSpace); | |||
} | |||
private void EnsureCapacity(int capacity) | |||
{ | |||
if (_buffer.Length >= capacity) | |||
{ | |||
return; | |||
} | |||
var newBufferLength = _buffer.Length; | |||
while (newBufferLength < capacity) | |||
{ | |||
newBufferLength *= 2; | |||
} | |||
Array.Resize(ref _buffer, newBufferLength); | |||
} | |||
private void IncreasePostition(int length) | |||
{ | |||
_position += length; | |||
if (_position > Length) | |||
{ | |||
Length = _position; | |||
} | |||
} | |||
} | |||
} |
@@ -11,7 +11,7 @@ using MQTTnet.Protocol; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttClientPendingMessagesQueue : IDisposable | |||
public class MqttClientPendingPacketsQueue : IDisposable | |||
{ | |||
private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent(); | |||
private readonly IMqttServerOptions _options; | |||
@@ -20,13 +20,13 @@ namespace MQTTnet.Server | |||
private ConcurrentQueue<MqttBasePacket> _queue = new ConcurrentQueue<MqttBasePacket>(); | |||
public MqttClientPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) | |||
public MqttClientPendingPacketsQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) | |||
{ | |||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | |||
_options = options ?? throw new ArgumentNullException(nameof(options)); | |||
_clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); | |||
_logger = logger.CreateChildLogger(nameof(MqttClientPendingMessagesQueue)); | |||
_logger = logger.CreateChildLogger(nameof(MqttClientPendingPacketsQueue)); | |||
} | |||
public int Count => _queue.Count; | |||
@@ -115,7 +115,7 @@ namespace MQTTnet.Server | |||
return; | |||
} | |||
await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, packet, cancellationToken).ConfigureAwait(false); | |||
await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); | |||
_logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); | |||
} |
@@ -18,7 +18,7 @@ namespace MQTTnet.Server | |||
private readonly MqttRetainedMessagesManager _retainedMessagesManager; | |||
private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; | |||
private readonly MqttClientPendingMessagesQueue _pendingMessagesQueue; | |||
private readonly MqttClientPendingPacketsQueue _pendingPacketsQueue; | |||
private readonly MqttClientSubscriptionsManager _subscriptionsManager; | |||
private readonly MqttClientSessionsManager _sessionsManager; | |||
@@ -49,7 +49,7 @@ namespace MQTTnet.Server | |||
_keepAliveMonitor = new MqttClientKeepAliveMonitor(clientId, () => Stop(MqttClientDisconnectType.NotClean), _logger); | |||
_subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, sessionsManager.Server); | |||
_pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger); | |||
_pendingPacketsQueue = new MqttClientPendingPacketsQueue(_options, this, _logger); | |||
} | |||
public string ClientId { get; } | |||
@@ -60,7 +60,7 @@ namespace MQTTnet.Server | |||
status.IsConnected = _adapter != null; | |||
status.Endpoint = _adapter?.Endpoint; | |||
status.ProtocolVersion = _adapter?.PacketSerializer?.ProtocolVersion; | |||
status.PendingApplicationMessagesCount = _pendingMessagesQueue.Count; | |||
status.PendingApplicationMessagesCount = _pendingPacketsQueue.Count; | |||
status.LastPacketReceived = _keepAliveMonitor.LastPacketReceived; | |||
status.LastNonKeepAlivePacketReceived = _keepAliveMonitor.LastNonKeepAlivePacketReceived; | |||
} | |||
@@ -80,7 +80,7 @@ namespace MQTTnet.Server | |||
_wasCleanDisconnect = false; | |||
_willMessage = connectPacket.WillMessage; | |||
_pendingMessagesQueue.Start(adapter, _cancellationTokenSource.Token); | |||
_pendingPacketsQueue.Start(adapter, _cancellationTokenSource.Token); | |||
_keepAliveMonitor.Start(connectPacket.KeepAlivePeriod, _cancellationTokenSource.Token); | |||
while (!_cancellationTokenSource.IsCancellationRequested) | |||
@@ -149,13 +149,10 @@ namespace MQTTnet.Server | |||
if (_willMessage != null && !_wasCleanDisconnect) | |||
{ | |||
_sessionsManager.StartDispatchApplicationMessage(this, _willMessage); | |||
_sessionsManager.EnqueueApplicationMessage(this, _willMessage); | |||
} | |||
_willMessage = null; | |||
////_pendingMessagesQueue.WaitForCompletion(); | |||
////_keepAliveMonitor.WaitForCompletion(); | |||
} | |||
finally | |||
{ | |||
@@ -196,7 +193,7 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
_pendingMessagesQueue.Enqueue(publishPacket); | |||
_pendingPacketsQueue.Enqueue(publishPacket); | |||
} | |||
public Task SubscribeAsync(IList<TopicFilter> topicFilters) | |||
@@ -226,12 +223,12 @@ namespace MQTTnet.Server | |||
public void ClearPendingApplicationMessages() | |||
{ | |||
_pendingMessagesQueue.Clear(); | |||
_pendingPacketsQueue.Clear(); | |||
} | |||
public void Dispose() | |||
{ | |||
_pendingMessagesQueue?.Dispose(); | |||
_pendingPacketsQueue?.Dispose(); | |||
_cancellationTokenSource?.Dispose(); | |||
} | |||
@@ -245,7 +242,7 @@ namespace MQTTnet.Server | |||
if (packet is MqttPingReqPacket) | |||
{ | |||
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket(), cancellationToken); | |||
return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); | |||
} | |||
if (packet is MqttPubRelPacket pubRelPacket) | |||
@@ -260,7 +257,7 @@ namespace MQTTnet.Server | |||
PacketIdentifier = pubRecPacket.PacketIdentifier | |||
}; | |||
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, responsePacket, cancellationToken); | |||
return adapter.SendPacketAsync(responsePacket, cancellationToken); | |||
} | |||
if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) | |||
@@ -308,7 +305,7 @@ namespace MQTTnet.Server | |||
private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) | |||
{ | |||
var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); | |||
await adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); | |||
await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); | |||
if (subscribeResult.CloseConnection) | |||
{ | |||
@@ -322,7 +319,7 @@ namespace MQTTnet.Server | |||
private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) | |||
{ | |||
var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); | |||
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, unsubscribeResult, cancellationToken); | |||
return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); | |||
} | |||
private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) | |||
@@ -333,7 +330,7 @@ namespace MQTTnet.Server | |||
{ | |||
case MqttQualityOfServiceLevel.AtMostOnce: | |||
{ | |||
_sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); | |||
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage); | |||
return Task.FromResult(0); | |||
} | |||
case MqttQualityOfServiceLevel.AtLeastOnce: | |||
@@ -353,25 +350,25 @@ namespace MQTTnet.Server | |||
private Task HandleIncomingPublishPacketWithQoS1(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) | |||
{ | |||
_sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); | |||
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage); | |||
var response = new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }; | |||
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); | |||
return adapter.SendPacketAsync(response, cancellationToken); | |||
} | |||
private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) | |||
{ | |||
// QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) | |||
_sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); | |||
_sessionsManager.EnqueueApplicationMessage(this, applicationMessage); | |||
var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; | |||
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); | |||
return adapter.SendPacketAsync(response, cancellationToken); | |||
} | |||
private Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) | |||
private static Task HandleIncomingPubRelPacketAsync(IMqttChannelAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) | |||
{ | |||
var response = new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }; | |||
return adapter.SendPacketAsync(_options.DefaultCommunicationTimeout, response, cancellationToken); | |||
return adapter.SendPacketAsync(response, cancellationToken); | |||
} | |||
private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) | |||
@@ -6,27 +6,29 @@ using System.Threading.Tasks; | |||
using MQTTnet.Adapter; | |||
using MQTTnet.Diagnostics; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Internal; | |||
using MQTTnet.Packets; | |||
using MQTTnet.Protocol; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttClientSessionsManager : IDisposable | |||
public class MqttClientSessionsManager | |||
{ | |||
private readonly BlockingCollection<MqttEnqueuedApplicationMessage> _messageQueue = new BlockingCollection<MqttEnqueuedApplicationMessage>(); | |||
private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>(); | |||
private readonly AsyncLock _sessionPreparationLock = new AsyncLock(); | |||
private readonly CancellationToken _cancellationToken; | |||
private readonly MqttRetainedMessagesManager _retainedMessagesManager; | |||
private readonly IMqttServerOptions _options; | |||
private readonly IMqttNetChildLogger _logger; | |||
public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetChildLogger logger) | |||
public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, CancellationToken cancellationToken, IMqttNetChildLogger logger) | |||
{ | |||
if (logger == null) throw new ArgumentNullException(nameof(logger)); | |||
_logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager)); | |||
_cancellationToken = cancellationToken; | |||
_options = options ?? throw new ArgumentNullException(nameof(options)); | |||
Server = server ?? throw new ArgumentNullException(nameof(server)); | |||
_retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); | |||
@@ -34,7 +36,129 @@ namespace MQTTnet.Server | |||
public MqttServer Server { get; } | |||
public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) | |||
public void Start() | |||
{ | |||
Task.Factory.StartNew(() => ProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); | |||
} | |||
public Task StopAsync() | |||
{ | |||
foreach (var session in _sessions) | |||
{ | |||
session.Value.Stop(MqttClientDisconnectType.NotClean); | |||
} | |||
_sessions.Clear(); | |||
return Task.FromResult(0); | |||
} | |||
public Task StartSession(IMqttChannelAdapter clientAdapter) | |||
{ | |||
return Task.Run(() => RunSession(clientAdapter, _cancellationToken), _cancellationToken); | |||
} | |||
public Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync() | |||
{ | |||
var result = new List<IMqttClientSessionStatus>(); | |||
foreach (var session in _sessions) | |||
{ | |||
var status = new MqttClientSessionStatus(this, session.Value); | |||
session.Value.FillStatus(status); | |||
result.Add(status); | |||
} | |||
return Task.FromResult((IList<IMqttClientSessionStatus>)result); | |||
} | |||
public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
{ | |||
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); | |||
_messageQueue.Add(new MqttEnqueuedApplicationMessage(senderClientSession, applicationMessage), _cancellationToken); | |||
} | |||
public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters) | |||
{ | |||
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
if (!_sessions.TryGetValue(clientId, out var session)) | |||
{ | |||
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); | |||
} | |||
return session.SubscribeAsync(topicFilters); | |||
} | |||
public Task UnsubscribeAsync(string clientId, IList<string> topicFilters) | |||
{ | |||
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
if (!_sessions.TryGetValue(clientId, out var session)) | |||
{ | |||
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); | |||
} | |||
return session.UnsubscribeAsync(topicFilters); | |||
} | |||
public void DeleteSession(string clientId) | |||
{ | |||
_sessions.TryRemove(clientId, out _); | |||
_logger.Verbose("Session for client '{0}' deleted.", clientId); | |||
} | |||
private void ProcessQueuedApplicationMessages(CancellationToken cancellationToken) | |||
{ | |||
while (!cancellationToken.IsCancellationRequested) | |||
{ | |||
try | |||
{ | |||
var enqueuedApplicationMessage = _messageQueue.Take(cancellationToken); | |||
var sender = enqueuedApplicationMessage.Sender; | |||
var applicationMessage = enqueuedApplicationMessage.ApplicationMessage; | |||
var interceptorContext = InterceptApplicationMessage(sender, applicationMessage); | |||
if (interceptorContext != null) | |||
{ | |||
if (interceptorContext.CloseConnection) | |||
{ | |||
enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean); | |||
} | |||
if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) | |||
{ | |||
return; | |||
} | |||
applicationMessage = interceptorContext.ApplicationMessage; | |||
} | |||
Server.OnApplicationMessageReceived(sender?.ClientId, applicationMessage); | |||
if (applicationMessage.Retain) | |||
{ | |||
_retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult(); | |||
} | |||
foreach (var clientSession in _sessions.Values) | |||
{ | |||
clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, applicationMessage); | |||
} | |||
} | |||
catch (TaskCanceledException) | |||
{ | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Unhandled exception while processing queued application message."); | |||
} | |||
} | |||
} | |||
private async Task RunSession(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) | |||
{ | |||
var clientId = string.Empty; | |||
var wasCleanDisconnect = false; | |||
@@ -60,7 +184,7 @@ namespace MQTTnet.Server | |||
var connectReturnCode = ValidateConnection(connectPacket); | |||
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) | |||
{ | |||
await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout, | |||
await clientAdapter.SendPacketAsync( | |||
new MqttConnAckPacket | |||
{ | |||
ConnectReturnCode = connectReturnCode | |||
@@ -70,15 +194,15 @@ namespace MQTTnet.Server | |||
return; | |||
} | |||
var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); | |||
var result = PrepareClientSession(connectPacket); | |||
var clientSession = result.Session; | |||
await clientAdapter.SendPacketAsync(_options.DefaultCommunicationTimeout, | |||
await clientAdapter.SendPacketAsync( | |||
new MqttConnAckPacket | |||
{ | |||
ConnectReturnCode = connectReturnCode, | |||
IsSessionPresent = result.IsExistingSession | |||
}, | |||
}, | |||
cancellationToken).ConfigureAwait(false); | |||
Server.OnClientConnected(clientId); | |||
@@ -113,73 +237,6 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
public Task StopAsync() | |||
{ | |||
foreach (var session in _sessions) | |||
{ | |||
session.Value.Stop(MqttClientDisconnectType.NotClean); | |||
} | |||
_sessions.Clear(); | |||
return Task.FromResult(0); | |||
} | |||
public Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync() | |||
{ | |||
var result = new List<IMqttClientSessionStatus>(); | |||
foreach (var session in _sessions) | |||
{ | |||
var status = new MqttClientSessionStatus(this, session.Value); | |||
session.Value.FillStatus(status); | |||
result.Add(status); | |||
} | |||
return Task.FromResult((IList<IMqttClientSessionStatus>)result); | |||
} | |||
public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
{ | |||
Task.Run(() => DispatchApplicationMessageAsync(senderClientSession, applicationMessage)); | |||
} | |||
public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters) | |||
{ | |||
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
if (!_sessions.TryGetValue(clientId, out var session)) | |||
{ | |||
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); | |||
} | |||
return session.SubscribeAsync(topicFilters); | |||
} | |||
public Task UnsubscribeAsync(string clientId, IList<string> topicFilters) | |||
{ | |||
if (clientId == null) throw new ArgumentNullException(nameof(clientId)); | |||
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); | |||
if (!_sessions.TryGetValue(clientId, out var session)) | |||
{ | |||
throw new InvalidOperationException($"Client session '{clientId}' is unknown."); | |||
} | |||
return session.UnsubscribeAsync(topicFilters); | |||
} | |||
public void DeleteSession(string clientId) | |||
{ | |||
_sessions.TryRemove(clientId, out _); | |||
_logger.Verbose("Session for client '{0}' deleted.", clientId); | |||
} | |||
public void Dispose() | |||
{ | |||
_sessionPreparationLock?.Dispose(); | |||
} | |||
private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket) | |||
{ | |||
if (_options.ConnectionValidator == null) | |||
@@ -197,9 +254,9 @@ namespace MQTTnet.Server | |||
return context.ReturnCode; | |||
} | |||
private async Task<GetOrCreateClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket) | |||
private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket) | |||
{ | |||
using (await _sessionPreparationLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) | |||
lock (_sessions) | |||
{ | |||
var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); | |||
if (isSessionPresent) | |||
@@ -231,60 +288,19 @@ namespace MQTTnet.Server | |||
_logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); | |||
} | |||
return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; | |||
return new PrepareClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession }; | |||
} | |||
} | |||
private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) | |||
{ | |||
try | |||
{ | |||
var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage); | |||
if (interceptorContext != null) | |||
{ | |||
if (interceptorContext.CloseConnection) | |||
{ | |||
senderClientSession.Stop(MqttClientDisconnectType.NotClean); | |||
} | |||
if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) | |||
{ | |||
return; | |||
} | |||
applicationMessage = interceptorContext.ApplicationMessage; | |||
} | |||
Server.OnApplicationMessageReceived(senderClientSession?.ClientId, applicationMessage); | |||
if (applicationMessage.Retain) | |||
{ | |||
await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false); | |||
} | |||
foreach (var clientSession in _sessions.Values) | |||
{ | |||
clientSession.EnqueueApplicationMessage(senderClientSession, applicationMessage); | |||
} | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while processing application message"); | |||
} | |||
} | |||
private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) | |||
{ | |||
var interceptorContext = new MqttApplicationMessageInterceptorContext( | |||
senderClientSession?.ClientId, | |||
applicationMessage); | |||
var interceptor = _options.ApplicationMessageInterceptor; | |||
if (interceptor == null) | |||
{ | |||
return interceptorContext; | |||
return null; | |||
} | |||
var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage); | |||
interceptor(interceptorContext); | |||
return interceptorContext; | |||
} | |||
@@ -0,0 +1,15 @@ | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttEnqueuedApplicationMessage | |||
{ | |||
public MqttEnqueuedApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) | |||
{ | |||
Sender = sender; | |||
ApplicationMessage = applicationMessage; | |||
} | |||
public MqttClientSession Sender { get; } | |||
public MqttApplicationMessage ApplicationMessage { get; } | |||
} | |||
} |
@@ -65,7 +65,7 @@ namespace MQTTnet.Server | |||
if (_cancellationTokenSource == null) throw new InvalidOperationException("The server is not started."); | |||
_clientSessionsManager.StartDispatchApplicationMessage(null, applicationMessage); | |||
_clientSessionsManager.EnqueueApplicationMessage(null, applicationMessage); | |||
return Task.FromResult(0); | |||
} | |||
@@ -81,7 +81,8 @@ namespace MQTTnet.Server | |||
_retainedMessagesManager = new MqttRetainedMessagesManager(Options, _logger); | |||
await _retainedMessagesManager.LoadMessagesAsync().ConfigureAwait(false); | |||
_clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _logger); | |||
_clientSessionsManager = new MqttClientSessionsManager(Options, this, _retainedMessagesManager, _cancellationTokenSource.Token, _logger); | |||
_clientSessionsManager.Start(); | |||
foreach (var adapter in _adapters) | |||
{ | |||
@@ -118,8 +119,6 @@ namespace MQTTnet.Server | |||
} | |||
finally | |||
{ | |||
_clientSessionsManager?.Dispose(); | |||
_cancellationTokenSource = null; | |||
_retainedMessagesManager = null; | |||
_clientSessionsManager = null; | |||
@@ -155,9 +154,7 @@ namespace MQTTnet.Server | |||
private void OnClientAccepted(object sender, MqttServerAdapterClientAcceptedEventArgs eventArgs) | |||
{ | |||
eventArgs.SessionTask = Task.Run( | |||
() => _clientSessionsManager.RunSessionAsync(eventArgs.Client, _cancellationTokenSource.Token), | |||
_cancellationTokenSource.Token); | |||
eventArgs.SessionTask = _clientSessionsManager.StartSession(eventArgs.Client); | |||
} | |||
} | |||
} |
@@ -1,6 +1,6 @@ | |||
namespace MQTTnet.Server | |||
{ | |||
public class GetOrCreateClientSessionResult | |||
public class PrepareClientSessionResult | |||
{ | |||
public bool IsExistingSession { get; set; } | |||
@@ -65,7 +65,7 @@ namespace MQTTnet.Benchmarks | |||
for (var i = 0; i < 10000; i++) | |||
{ | |||
_channelAdapter.SendPacketAsync(TimeSpan.FromSeconds(15), _packet, CancellationToken.None).GetAwaiter().GetResult(); | |||
_channelAdapter.SendPacketAsync(_packet, CancellationToken.None).GetAwaiter().GetResult(); | |||
} | |||
_stream.Position = 0; | |||
@@ -1,30 +0,0 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Serializer; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
[TestClass] | |||
public class ByteReaderTests | |||
{ | |||
[TestMethod] | |||
public void ByteReader_ReadToEnd() | |||
{ | |||
var reader = new ByteReader(85); | |||
Assert.IsTrue(reader.Read()); | |||
Assert.IsFalse(reader.Read()); | |||
Assert.IsTrue(reader.Read()); | |||
Assert.IsFalse(reader.Read()); | |||
Assert.IsTrue(reader.Read()); | |||
Assert.IsFalse(reader.Read()); | |||
Assert.IsTrue(reader.Read()); | |||
Assert.IsFalse(reader.Read()); | |||
} | |||
[TestMethod] | |||
public void ByteReader_ReadPartial() | |||
{ | |||
var reader = new ByteReader(15); | |||
Assert.AreEqual(3, reader.Read(2)); | |||
} | |||
} | |||
} |
@@ -1,51 +0,0 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Serializer; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
[TestClass] | |||
public class ByteWriterTests | |||
{ | |||
[TestMethod] | |||
public void ByteWriter_WriteMultipleAll() | |||
{ | |||
var b = new ByteWriter(); | |||
Assert.AreEqual(0, b.Value); | |||
b.Write(3, 2); | |||
Assert.AreEqual(3, b.Value); | |||
} | |||
[TestMethod] | |||
public void ByteWriter_WriteMultiplePartial() | |||
{ | |||
var b = new ByteWriter(); | |||
Assert.AreEqual(0, b.Value); | |||
b.Write(255, 2); | |||
Assert.AreEqual(3, b.Value); | |||
} | |||
[TestMethod] | |||
public void ByteWriter_WriteTo0xFF() | |||
{ | |||
var b = new ByteWriter(); | |||
Assert.AreEqual(0, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(1, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(3, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(7, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(15, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(31, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(63, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(127, b.Value); | |||
b.Write(true); | |||
Assert.AreEqual(255, b.Value); | |||
} | |||
} | |||
} |
@@ -14,20 +14,20 @@ namespace MQTTnet.Core.Tests | |||
[TestMethod] | |||
public async Task TimeoutAfter() | |||
{ | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||
} | |||
[ExpectedException(typeof(MqttCommunicationTimedOutException))] | |||
[TestMethod] | |||
public async Task TimeoutAfterWithResult() | |||
{ | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(500), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(100), CancellationToken.None); | |||
} | |||
[TestMethod] | |||
public async Task TimeoutAfterCompleteInTime() | |||
{ | |||
var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); | |||
var result = await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(100), ct).ContinueWith(t => 5, ct), TimeSpan.FromMilliseconds(500), CancellationToken.None); | |||
Assert.AreEqual(5, result); | |||
} | |||
@@ -36,7 +36,7 @@ namespace MQTTnet.Core.Tests | |||
{ | |||
try | |||
{ | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => | |||
{ | |||
var iis = new int[0]; | |||
iis[1] = 0; | |||
@@ -55,7 +55,7 @@ namespace MQTTnet.Core.Tests | |||
{ | |||
try | |||
{ | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Run(() => | |||
await MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Run(() => | |||
{ | |||
var iis = new int[0]; | |||
iis[1] = 0; | |||
@@ -76,7 +76,7 @@ namespace MQTTnet.Core.Tests | |||
var tasks = Enumerable.Range(0, 100000) | |||
.Select(i => | |||
{ | |||
return MQTTnet.Internal.TaskExtensions.TimeoutAfter(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); | |||
return MQTTnet.Internal.TaskExtensions.TimeoutAfterAsync(ct => Task.Delay(TimeSpan.FromMilliseconds(1), ct), TimeSpan.FromMinutes(1), CancellationToken.None); | |||
}); | |||
await Task.WhenAll(tasks); | |||
@@ -11,7 +11,7 @@ namespace MQTTnet.Core.Tests | |||
public class MqttPacketReaderTests | |||
{ | |||
[TestMethod] | |||
[ExpectedException(typeof(MqttCommunicationException))] | |||
[ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] | |||
public void MqttPacketReader_EmptyStream() | |||
{ | |||
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); | |||
@@ -1,6 +1,5 @@ | |||
using System; | |||
using System.Collections.Concurrent; | |||
using System.Collections.Generic; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Adapter; | |||
@@ -36,7 +35,7 @@ namespace MQTTnet.Core.Tests | |||
return Task.FromResult(0); | |||
} | |||
public Task SendPacketAsync(TimeSpan timeout, MqttBasePacket packet, CancellationToken cancellationToken) | |||
public Task SendPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) | |||
{ | |||
ThrowIfPartnerIsNull(); | |||
@@ -12,16 +12,46 @@ namespace MQTTnet.TestApp.NetCore | |||
{ | |||
public static class PerformanceTest | |||
{ | |||
public static async Task RunAsync() | |||
public static void Run() | |||
{ | |||
Console.WriteLine("Press 'c' for concurrent sends. Otherwise in one batch."); | |||
var concurrent = Console.ReadKey(true).KeyChar == 'c'; | |||
try | |||
{ | |||
var mqttServer = new MqttFactory().CreateMqttServer(); | |||
mqttServer.StartAsync(new MqttServerOptions()).GetAwaiter().GetResult(); | |||
var server = Task.Run(RunServerAsync); | |||
await Task.Delay(1000); | |||
var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10), concurrent)); | |||
var options = new MqttClientOptions | |||
{ | |||
ChannelOptions = new MqttClientTcpOptions | |||
{ | |||
Server = "127.0.0.1" | |||
}, | |||
CleanSession = true | |||
}; | |||
var client = new MqttFactory().CreateMqttClient(); | |||
client.ConnectAsync(options).GetAwaiter().GetResult(); | |||
var message = CreateMessage(); | |||
var stopwatch = new Stopwatch(); | |||
await Task.WhenAll(server, client).ConfigureAwait(false); | |||
for (var i = 0; i < 10; i++) | |||
{ | |||
stopwatch.Restart(); | |||
var sentMessagesCount = 0; | |||
while (stopwatch.ElapsedMilliseconds < 1000) | |||
{ | |||
client.PublishAsync(message).GetAwaiter().GetResult(); | |||
sentMessagesCount++; | |||
} | |||
Console.WriteLine($"Sending {sentMessagesCount} messages per second. #" + (i + 1)); | |||
} | |||
} | |||
catch (Exception exception) | |||
{ | |||
Console.WriteLine(exception); | |||
} | |||
} | |||
private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval, bool concurrent) | |||
@@ -53,29 +83,8 @@ namespace MQTTnet.TestApp.NetCore | |||
} | |||
var message = CreateMessage(); | |||
var messages = new[] { message }; | |||
var stopwatch = Stopwatch.StartNew(); | |||
var sentMessagesCount = 0; | |||
while (stopwatch.ElapsedMilliseconds < 1000) | |||
{ | |||
client.PublishAsync(messages).GetAwaiter().GetResult(); | |||
sentMessagesCount++; | |||
} | |||
Console.WriteLine($"Sending {sentMessagesCount} messages per second. #1"); | |||
sentMessagesCount = 0; | |||
stopwatch.Restart(); | |||
while (stopwatch.ElapsedMilliseconds < 1000) | |||
{ | |||
await client.PublishAsync(messages).ConfigureAwait(false); | |||
sentMessagesCount++; | |||
} | |||
Console.WriteLine($"Sending {sentMessagesCount} messages per second. #2"); | |||
var testMessageCount = 10000; | |||
for (var i = 0; i < testMessageCount; i++) | |||
{ | |||
@@ -142,38 +151,5 @@ namespace MQTTnet.TestApp.NetCore | |||
Interlocked.Increment(ref count); | |||
return Task.Run(() => client.PublishAsync(applicationMessage)); | |||
} | |||
private static async Task RunServerAsync() | |||
{ | |||
try | |||
{ | |||
var mqttServer = new MqttFactory().CreateMqttServer(); | |||
////var msgs = 0; | |||
////var stopwatch = Stopwatch.StartNew(); | |||
////mqttServer.ApplicationMessageReceived += (sender, args) => | |||
////{ | |||
//// msgs++; | |||
//// if (stopwatch.ElapsedMilliseconds > 1000) | |||
//// { | |||
//// Console.WriteLine($"received {msgs}"); | |||
//// msgs = 0; | |||
//// stopwatch.Restart(); | |||
//// } | |||
////}; | |||
await mqttServer.StartAsync(new MqttServerOptions()); | |||
Console.WriteLine("Press any key to exit."); | |||
Console.ReadLine(); | |||
await mqttServer.StopAsync().ConfigureAwait(false); | |||
} | |||
catch (Exception e) | |||
{ | |||
Console.WriteLine(e); | |||
} | |||
Console.ReadLine(); | |||
} | |||
} | |||
} |
@@ -34,7 +34,8 @@ namespace MQTTnet.TestApp.NetCore | |||
} | |||
else if (pressedKey.KeyChar == '3') | |||
{ | |||
Task.Run(PerformanceTest.RunAsync); | |||
PerformanceTest.Run(); | |||
return; | |||
} | |||
else if (pressedKey.KeyChar == '4') | |||
{ | |||