diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index 73588ef..218fcb9 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -97,35 +97,29 @@ namespace MQTTnet.Adapter MqttBasePacket packet = null; await ExecuteAndWrapExceptionAsync(async () => { - ReceivedMqttPacket receivedMqttPacket = null; - try - { - if (timeout > TimeSpan.Zero) - { - receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); - } - else - { - receivedMqttPacket = await ReceiveAsync(_channel, cancellationToken).ConfigureAwait(false); - } - - if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } + ReceivedMqttPacket receivedMqttPacket; - packet = PacketSerializer.Deserialize(receivedMqttPacket); - if (packet == null) - { - throw new MqttProtocolViolationException("Received malformed packet."); - } + if (timeout > TimeSpan.Zero) + { + receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false); + } + else + { + receivedMqttPacket = await ReceiveAsync(_channel, cancellationToken).ConfigureAwait(false); + } - _logger.Verbose("RX <<< {0}", packet); + if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) + { + return; } - finally + + packet = PacketSerializer.Deserialize(receivedMqttPacket); + if (packet == null) { - receivedMqttPacket?.Dispose(); + throw new MqttProtocolViolationException("Received malformed packet."); } + + _logger.Verbose("RX <<< {0}", packet); }).ConfigureAwait(false); return packet; @@ -134,7 +128,11 @@ namespace MQTTnet.Adapter private async Task ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) { var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, cancellationToken).ConfigureAwait(false); - + if (fixedHeader == null) + { + return null; + } + try { ReadingPacketStarted?.Invoke(this, EventArgs.Empty); @@ -144,31 +142,28 @@ namespace MQTTnet.Adapter return new ReceivedMqttPacket(fixedHeader.Flags, null); } - var body = new MemoryStream(fixedHeader.RemainingLength); + var body = new byte[fixedHeader.RemainingLength]; + var bodyOffset = 0; + var chunkSize = Math.Min(ReadBufferSize, fixedHeader.RemainingLength); - var buffer = new byte[Math.Min(ReadBufferSize, fixedHeader.RemainingLength)]; - while (body.Length < fixedHeader.RemainingLength) + do { - var bytesLeft = fixedHeader.RemainingLength - (int)body.Length; - if (bytesLeft > buffer.Length) + var bytesLeft = body.Length - bodyOffset; + if (chunkSize > bytesLeft) { - bytesLeft = buffer.Length; + chunkSize = bytesLeft; } - var readBytes = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); + var readBytes = await channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken) .ConfigureAwait(false); if (readBytes <= 0) { ExceptionHelper.ThrowGracefulSocketClose(); } - // Here is no need to await because internally only an array is used and no real I/O operation is made. - // Using async here will only generate overhead. - body.Write(buffer, 0, readBytes); - } - - body.Seek(0L, SeekOrigin.Begin); + bodyOffset += readBytes; + } while (bodyOffset < body.Length); - return new ReceivedMqttPacket(fixedHeader.Flags, body); + return new ReceivedMqttPacket(fixedHeader.Flags, new MqttPacketBodyReader(body)); } finally { diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs b/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs index 1f1c59a..3a92d4d 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/ReceivedMqttPacket.cs @@ -1,11 +1,10 @@ -using System; -using System.IO; +using MQTTnet.Serializer; namespace MQTTnet.Adapter { - public class ReceivedMqttPacket : IDisposable + public class ReceivedMqttPacket { - public ReceivedMqttPacket(byte fixedHeader, MemoryStream body) + public ReceivedMqttPacket(byte fixedHeader, MqttPacketBodyReader body) { FixedHeader = fixedHeader; Body = body; @@ -13,11 +12,6 @@ namespace MQTTnet.Adapter public byte FixedHeader { get; } - public MemoryStream Body { get; } - - public void Dispose() - { - Body?.Dispose(); - } + public MqttPacketBodyReader Body { get; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs index 44bc5b4..4569be3 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs @@ -9,5 +9,10 @@ public int BufferSize { get; set; } = 4096; public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions(); + + public override string ToString() + { + return Server + ":" + this.GetPort(); + } } } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs index 9718298..a4dd0d5 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientWebSocketOptions.cs @@ -13,8 +13,11 @@ namespace MQTTnet.Client public CookieContainer CookieContainer { get; set; } - public int BufferSize { get; set; } = 4096; - public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions(); + + public override string ToString() + { + return Uri; + } } } diff --git a/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs b/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs index 21d0584..e48b8ed 100644 --- a/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Internal/TestMqttChannel.cs @@ -3,7 +3,7 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; -namespace MQTTnet.Core.Internal +namespace MQTTnet.Internal { public class TestMqttChannel : IMqttChannel { diff --git a/Frameworks/MQTTnet.NetStandard/MqttFactory.cs b/Frameworks/MQTTnet.NetStandard/MqttFactory.cs index 0ba3d2f..00438fa 100644 --- a/Frameworks/MQTTnet.NetStandard/MqttFactory.cs +++ b/Frameworks/MQTTnet.NetStandard/MqttFactory.cs @@ -4,7 +4,6 @@ using MQTTnet.Adapter; using MQTTnet.Client; using MQTTnet.Diagnostics; using MQTTnet.Implementations; -using MQTTnet.ManagedClient; using MQTTnet.Server; namespace MQTTnet @@ -23,18 +22,6 @@ namespace MQTTnet return new MqttClient(new MqttClientAdapterFactory(), logger); } - public IManagedMqttClient CreateManagedMqttClient() - { - return new ManagedMqttClient(CreateMqttClient(), new MqttNetLogger().CreateChildLogger()); - } - - public IManagedMqttClient CreateManagedMqttClient(IMqttNetLogger logger) - { - if (logger == null) throw new ArgumentNullException(nameof(logger)); - - return new ManagedMqttClient(CreateMqttClient(), logger.CreateChildLogger()); - } - public IMqttServer CreateMqttServer() { var logger = new MqttNetLogger(); diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/Extensions.cs b/Frameworks/MQTTnet.NetStandard/Serializer/Extensions.cs new file mode 100644 index 0000000..1a14de4 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Serializer/Extensions.cs @@ -0,0 +1,20 @@ +using System; + +namespace MQTTnet.Serializer +{ + public static class Extensions + { + public static byte[] ToArray(this ArraySegment source) + { + if (source.Array == null) + { + return null; + } + + var buffer = new byte[source.Count]; + Buffer.BlockCopy(source.Array, source.Offset, buffer, 0, buffer.Length); + + return buffer; + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttFixedHeader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttFixedHeader.cs index a8c2015..d87f63d 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttFixedHeader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttFixedHeader.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Serializer { - public struct MqttFixedHeader + public class MqttFixedHeader { public MqttFixedHeader(byte flags, int remainingLength) { @@ -10,6 +10,6 @@ public byte Flags { get; } - public int RemainingLength { get; } + public int RemainingLength { get; set; } } } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketBodyReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketBodyReader.cs new file mode 100644 index 0000000..04c0a22 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketBodyReader.cs @@ -0,0 +1,54 @@ +using System; +using System.Text; + +namespace MQTTnet.Serializer +{ + public class MqttPacketBodyReader + { + private readonly byte[] _buffer; + private int _offset; + + public MqttPacketBodyReader(byte[] buffer) + { + _buffer = buffer; + } + + public int Length => _buffer.Length - _offset; + + public bool EndOfStream => _offset == _buffer.Length; + + public byte ReadByte() + { + return _buffer[_offset++]; + } + + public ArraySegment ReadRemainingData() + { + return new ArraySegment(_buffer, _offset, _buffer.Length - _offset); + } + + public ushort ReadUInt16() + { + var msb = _buffer[_offset++]; + var lsb = _buffer[_offset++]; + + return (ushort)(msb << 8 | lsb); + } + + public ArraySegment ReadWithLengthPrefix() + { + var length = ReadUInt16(); + + var result = new ArraySegment(_buffer, _offset, length); + _offset += length; + + return result; + } + + public string ReadStringWithLengthPrefix() + { + var buffer = ReadWithLengthPrefix(); + return Encoding.UTF8.GetString(buffer.Array, buffer.Offset, buffer.Count); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs index 110407e..7ed918f 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs @@ -1,7 +1,4 @@ -using System; -using System.IO; -using System.Text; -using System.Threading; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; using MQTTnet.Exceptions; @@ -23,6 +20,11 @@ namespace MQTTnet.Serializer var bytesRead = await channel.ReadAsync(buffer, 0, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) { + if (cancellationToken.IsCancellationRequested) + { + return null; + } + ExceptionHelper.ThrowGracefulSocketClose(); } @@ -35,78 +37,20 @@ namespace MQTTnet.Serializer return new MqttFixedHeader(buffer[0], 0); } - var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken); + var bodyLength = await ReadBodyLengthAsync(channel, buffer[1], cancellationToken).ConfigureAwait(false); return new MqttFixedHeader(buffer[0], bodyLength); } - public static ushort ReadUInt16(this Stream stream) - { - var buffer = stream.ReadBytes(2); - - var temp = buffer[0]; - buffer[0] = buffer[1]; - buffer[1] = temp; - - return BitConverter.ToUInt16(buffer, 0); - } - - public static string ReadStringWithLengthPrefix(this Stream stream) - { - var buffer = stream.ReadWithLengthPrefix(); - if (buffer.Length == 0) - { - return string.Empty; - } - - return Encoding.UTF8.GetString(buffer, 0, buffer.Length); - } - - public static byte[] ReadWithLengthPrefix(this Stream stream) - { - var length = stream.ReadUInt16(); - if (length == 0) - { - return new byte[0]; - } - - return stream.ReadBytes(length); - } - - public static byte[] ReadRemainingData(this Stream stream) - { - return stream.ReadBytes((int)(stream.Length - stream.Position)); - } - - private static byte[] ReadBytes(this Stream stream, int count) - { - var buffer = new byte[count]; - var readBytes = stream.Read(buffer, 0, count); - - if (readBytes != count) - { - throw new InvalidOperationException($"Unable to read {count} bytes from the stream."); - } - - return buffer; - } - private static async Task ReadBodyLengthAsync(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 128; var value = initialEncodedByte & 127; int encodedByte = initialEncodedByte; - var buffer = new byte[1]; - + while ((encodedByte & 128) != 0) { - var readCount = await channel.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); - if (readCount <= 0) - { - ExceptionHelper.ThrowGracefulSocketClose(); - } - - encodedByte = buffer[0]; + encodedByte = await ReadByteAsync(channel, cancellationToken).ConfigureAwait(false); value += (byte)(encodedByte & 127) * multiplier; if (multiplier > 128 * 128 * 128) @@ -119,5 +63,17 @@ namespace MQTTnet.Serializer return value; } + + private static async Task ReadByteAsync(IMqttChannel channel, CancellationToken cancellationToken) + { + var buffer = new byte[1]; + var readCount = await channel.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); + if (readCount <= 0) + { + ExceptionHelper.ThrowGracefulSocketClose(); + } + + return buffer[0]; + } } } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 32001db..2acbfbb 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -104,7 +104,7 @@ namespace MQTTnet.Serializer } } - private static MqttBasePacket DeserializeUnsubAck(Stream body) + private static MqttBasePacket DeserializeUnsubAck(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -114,7 +114,7 @@ namespace MQTTnet.Serializer }; } - private static MqttBasePacket DeserializePubComp(Stream body) + private static MqttBasePacket DeserializePubComp(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -124,7 +124,7 @@ namespace MQTTnet.Serializer }; } - private static MqttBasePacket DeserializePubRel(Stream body) + private static MqttBasePacket DeserializePubRel(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -134,7 +134,7 @@ namespace MQTTnet.Serializer }; } - private static MqttBasePacket DeserializePubRec(Stream body) + private static MqttBasePacket DeserializePubRec(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -144,7 +144,7 @@ namespace MQTTnet.Serializer }; } - private static MqttBasePacket DeserializePubAck(Stream body) + private static MqttBasePacket DeserializePubAck(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -154,7 +154,7 @@ namespace MQTTnet.Serializer }; } - private static MqttBasePacket DeserializeUnsubscribe(Stream body) + private static MqttBasePacket DeserializeUnsubscribe(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -163,7 +163,7 @@ namespace MQTTnet.Serializer PacketIdentifier = body.ReadUInt16(), }; - while (body.Position != body.Length) + while (!body.EndOfStream) { packet.TopicFilters.Add(body.ReadStringWithLengthPrefix()); } @@ -171,7 +171,7 @@ namespace MQTTnet.Serializer return packet; } - private static MqttBasePacket DeserializeSubscribe(Stream body) + private static MqttBasePacket DeserializeSubscribe(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -180,7 +180,7 @@ namespace MQTTnet.Serializer PacketIdentifier = body.ReadUInt16() }; - while (body.Position != body.Length) + while (!body.EndOfStream) { packet.TopicFilters.Add(new TopicFilter( body.ReadStringWithLengthPrefix(), @@ -213,7 +213,7 @@ namespace MQTTnet.Serializer PacketIdentifier = packetIdentifier, Retain = retain, Topic = topic, - Payload = body.ReadRemainingData(), + Payload = body.ReadRemainingData().ToArray(), QualityOfServiceLevel = qualityOfServiceLevel, Dup = dup }; @@ -221,7 +221,7 @@ namespace MQTTnet.Serializer return packet; } - private static MqttBasePacket DeserializeConnect(Stream body) + private static MqttBasePacket DeserializeConnect(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -279,7 +279,7 @@ namespace MQTTnet.Serializer packet.WillMessage = new MqttApplicationMessage { Topic = body.ReadStringWithLengthPrefix(), - Payload = body.ReadWithLengthPrefix(), + Payload = body.ReadWithLengthPrefix().ToArray(), QualityOfServiceLevel = (MqttQualityOfServiceLevel)willQoS, Retain = willRetain }; @@ -299,7 +299,7 @@ namespace MQTTnet.Serializer return packet; } - private static MqttBasePacket DeserializeSubAck(Stream body) + private static MqttBasePacket DeserializeSubAck(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -308,7 +308,7 @@ namespace MQTTnet.Serializer PacketIdentifier = body.ReadUInt16() }; - while (body.Position != body.Length) + while (!body.EndOfStream) { packet.SubscribeReturnCodes.Add((MqttSubscribeReturnCode)body.ReadByte()); } @@ -316,7 +316,7 @@ namespace MQTTnet.Serializer return packet; } - private MqttBasePacket DeserializeConnAck(Stream body) + private MqttBasePacket DeserializeConnAck(MqttPacketBodyReader body) { ThrowIfBodyIsEmpty(body); @@ -614,7 +614,7 @@ namespace MQTTnet.Serializer return MqttPacketWriter.BuildFixedHeader(type); } - private static void ThrowIfBodyIsEmpty(Stream body) + private static void ThrowIfBodyIsEmpty(MqttPacketBodyReader body) { if (body == null || body.Length == 0) { diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs index 4fd7266..cb1043b 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs @@ -86,8 +86,11 @@ namespace MQTTnet.Server while (!_cancellationTokenSource.IsCancellationRequested) { var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false); - _keepAliveMonitor.PacketReceived(packet); - await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false); + if (packet != null) + { + _keepAliveMonitor.PacketReceived(packet); + await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false); + } } } catch (OperationCanceledException) @@ -351,7 +354,7 @@ namespace MQTTnet.Server private Task HandleIncomingPublishPacketWithQoS2(IMqttChannelAdapter adapter, MqttApplicationMessage applicationMessage, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - // QoS 2 is implement as method "B" [4.3.3 QoS 2: Exactly once delivery] + // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) _sessionsManager.StartDispatchApplicationMessage(this, applicationMessage); var response = new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }; diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs index 692e9c6..dadf32a 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs @@ -41,11 +41,15 @@ namespace MQTTnet.Server try { - if (!(await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken) - .ConfigureAwait(false) is MqttConnectPacket connectPacket)) + var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); + if (firstPacket == null) { - throw new MqttProtocolViolationException( - "The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1]."); + return; + } + + if (!(firstPacket is MqttConnectPacket connectPacket)) + { + throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1]."); } clientId = connectPacket.ClientId;