diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs index a64b445..47dfb45 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs @@ -10,6 +10,7 @@ using MQTTnet.Core.Adapter; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Serializer; using MQTTnet.Core.Server; +using MQTTnet.Core.Channel; namespace MQTTnet.Implementations { @@ -86,7 +87,7 @@ namespace MQTTnet.Implementations try { var clientSocket = await Task.Factory.FromAsync(_defaultEndpointSocket.BeginAccept, _defaultEndpointSocket.EndAccept, null).ConfigureAwait(false); - var clientAdapter = new MqttChannelCommunicationAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer()); + var clientAdapter = new MqttChannelCommunicationAdapter(new BufferedCommunicationChannel(new MqttTcpChannel(clientSocket, null)), new MqttPacketSerializer()); ClientConnected?.Invoke(this, new MqttClientConnectedEventArgs(clientSocket.RemoteEndPoint.ToString(), clientAdapter)); } catch (Exception exception) when (!(exception is ObjectDisposedException)) diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs index 5626256..66fc293 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs @@ -93,17 +93,15 @@ namespace MQTTnet.Implementations } } - public async Task ReadAsync(byte[] buffer) + public async Task> ReadAsync(int length, byte[] buffer) { - if (buffer == null) throw new ArgumentNullException(nameof(buffer)); - try { var totalBytes = 0; do { - var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false); + var read = await _dataStream.ReadAsync(buffer, totalBytes, length - totalBytes).ConfigureAwait(false); if (read == 0) { throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); @@ -111,7 +109,8 @@ namespace MQTTnet.Implementations totalBytes += read; } - while (totalBytes < buffer.Length); + while (totalBytes < length); + return new ArraySegment(buffer, 0, length); } catch (SocketException exception) { @@ -143,5 +142,10 @@ namespace MQTTnet.Implementations return certificates; } + + public int Peek() + { + return _socket.Available; + } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs index b90b02a..94f1330 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs @@ -11,9 +11,6 @@ namespace MQTTnet.Implementations public sealed class MqttWebSocketChannel : IMqttCommunicationChannel, IDisposable { private ClientWebSocket _webSocket = new ClientWebSocket(); - private const int BufferSize = 4096; - private const int BufferAmplifier = 20; - private readonly byte[] WebSocketBuffer = new byte[BufferSize * BufferAmplifier]; private int WebSocketBufferSize; private int WebSocketBufferOffset; @@ -42,50 +39,39 @@ namespace MQTTnet.Implementations _webSocket?.Dispose(); } - public Task ReadAsync(byte[] buffer) + public async Task> ReadAsync(int length, byte[] buffer) { - return Task.WhenAll(ReadToBufferAsync(buffer)); + await ReadToBufferAsync(length, buffer).ConfigureAwait(false); + + var result = new ArraySegment(buffer, WebSocketBufferOffset, length); + WebSocketBufferSize -= length; + WebSocketBufferOffset += length; + + return result; } - private async Task ReadToBufferAsync(byte[] buffer) + private async Task ReadToBufferAsync(int length, byte[] buffer) { - var temporaryBuffer = new byte[BufferSize]; - var offset = 0; + if (WebSocketBufferSize > 0) + { + return; + } - while (_webSocket.State == WebSocketState.Open) + var offset = 0; + while (_webSocket.State == WebSocketState.Open && WebSocketBufferSize < length) { - if (WebSocketBufferSize == 0) + WebSocketReceiveResult response; + do { - WebSocketBufferOffset = 0; - - WebSocketReceiveResult response; - do - { - response = await _webSocket.ReceiveAsync(new ArraySegment(temporaryBuffer), CancellationToken.None).ConfigureAwait(false); - - temporaryBuffer.CopyTo(WebSocketBuffer, offset); - offset += response.Count; - temporaryBuffer = new byte[BufferSize]; - } while (!response.EndOfMessage); + response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, offset, buffer.Length - offset), CancellationToken.None).ConfigureAwait(false); + offset += response.Count; + } while (!response.EndOfMessage); - WebSocketBufferSize = response.Count; - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); - } - - Buffer.BlockCopy(WebSocketBuffer, 0, buffer, 0, buffer.Length); - WebSocketBufferSize -= buffer.Length; - WebSocketBufferOffset += buffer.Length; - } - else + WebSocketBufferSize = response.Count; + if (response.MessageType == WebSocketMessageType.Close) { - Buffer.BlockCopy(WebSocketBuffer, WebSocketBufferOffset, buffer, 0, buffer.Length); - WebSocketBufferSize -= buffer.Length; - WebSocketBufferOffset += buffer.Length; + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); } - - return; } } @@ -105,5 +91,10 @@ namespace MQTTnet.Implementations throw new MqttCommunicationException(exception); } } + + public int Peek() + { + return WebSocketBufferSize; + } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs b/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs index 8075f92..b3e0080 100644 --- a/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs +++ b/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs @@ -22,7 +22,7 @@ namespace MQTTnet { case MqttConnectionType.Tcp: case MqttConnectionType.Tls: - return new MqttTcpChannel(); + return new BufferedCommunicationChannel( new MqttTcpChannel() ); case MqttConnectionType.Ws: case MqttConnectionType.Wss: return new MqttWebSocketChannel(); diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index 5b20cf1..87c89b1 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -92,17 +92,15 @@ namespace MQTTnet.Implementations } } - public async Task ReadAsync(byte[] buffer) + public async Task> ReadAsync(int length, byte[] buffer) { - if (buffer == null) throw new ArgumentNullException(nameof(buffer)); - try { var totalBytes = 0; do { - var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false); + var read = await _dataStream.ReadAsync(buffer, totalBytes, length - totalBytes).ConfigureAwait(false); if (read == 0) { throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); @@ -110,7 +108,8 @@ namespace MQTTnet.Implementations totalBytes += read; } - while (totalBytes < buffer.Length); + while (totalBytes < length); + return new ArraySegment(buffer, 0, length); } catch (SocketException exception) { @@ -142,5 +141,10 @@ namespace MQTTnet.Implementations return certificates; } + + public int Peek() + { + return _socket.Available; + } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs index b90b02a..2aa31b4 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs @@ -11,9 +11,6 @@ namespace MQTTnet.Implementations public sealed class MqttWebSocketChannel : IMqttCommunicationChannel, IDisposable { private ClientWebSocket _webSocket = new ClientWebSocket(); - private const int BufferSize = 4096; - private const int BufferAmplifier = 20; - private readonly byte[] WebSocketBuffer = new byte[BufferSize * BufferAmplifier]; private int WebSocketBufferSize; private int WebSocketBufferOffset; @@ -42,50 +39,39 @@ namespace MQTTnet.Implementations _webSocket?.Dispose(); } - public Task ReadAsync(byte[] buffer) + public async Task> ReadAsync(int length, byte[] buffer) { - return Task.WhenAll(ReadToBufferAsync(buffer)); + await ReadToBufferAsync(length, buffer).ConfigureAwait(false); + + var result = new ArraySegment(buffer, WebSocketBufferOffset, length); + WebSocketBufferSize -= length; + WebSocketBufferOffset += length; + + return result; } - private async Task ReadToBufferAsync(byte[] buffer) + private async Task ReadToBufferAsync(int length, byte[] buffer) { - var temporaryBuffer = new byte[BufferSize]; - var offset = 0; + if (WebSocketBufferSize > 0) + { + return; + } - while (_webSocket.State == WebSocketState.Open) + var offset = 0; + while (_webSocket.State == WebSocketState.Open && WebSocketBufferSize < length) { - if (WebSocketBufferSize == 0) + WebSocketReceiveResult response; + do { - WebSocketBufferOffset = 0; - - WebSocketReceiveResult response; - do - { - response = await _webSocket.ReceiveAsync(new ArraySegment(temporaryBuffer), CancellationToken.None).ConfigureAwait(false); + response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, offset, buffer.Length - offset), CancellationToken.None).ConfigureAwait(false); + offset += response.Count; + } while (!response.EndOfMessage); - temporaryBuffer.CopyTo(WebSocketBuffer, offset); - offset += response.Count; - temporaryBuffer = new byte[BufferSize]; - } while (!response.EndOfMessage); - - WebSocketBufferSize = response.Count; - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); - } - - Buffer.BlockCopy(WebSocketBuffer, 0, buffer, 0, buffer.Length); - WebSocketBufferSize -= buffer.Length; - WebSocketBufferOffset += buffer.Length; - } - else + WebSocketBufferSize = response.Count; + if (response.MessageType == WebSocketMessageType.Close) { - Buffer.BlockCopy(WebSocketBuffer, WebSocketBufferOffset, buffer, 0, buffer.Length); - WebSocketBufferSize -= buffer.Length; - WebSocketBufferOffset += buffer.Length; + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); } - - return; } } @@ -105,5 +91,10 @@ namespace MQTTnet.Implementations throw new MqttCommunicationException(exception); } } + + public int Peek() + { + return WebSocketBufferSize; + } } } \ No newline at end of file diff --git a/MQTTnet.Core/Channel/BufferConstants.cs b/MQTTnet.Core/Channel/BufferConstants.cs new file mode 100644 index 0000000..3b441e0 --- /dev/null +++ b/MQTTnet.Core/Channel/BufferConstants.cs @@ -0,0 +1,7 @@ +namespace MQTTnet.Core.Channel +{ + public static class BufferConstants + { + public const int Size = 4096 * 20; + } +} diff --git a/MQTTnet.Core/Channel/BufferedCommunicationChannel.cs b/MQTTnet.Core/Channel/BufferedCommunicationChannel.cs new file mode 100644 index 0000000..34ee586 --- /dev/null +++ b/MQTTnet.Core/Channel/BufferedCommunicationChannel.cs @@ -0,0 +1,77 @@ +using System.Threading.Tasks; +using MQTTnet.Core.Client; +using System; + +namespace MQTTnet.Core.Channel +{ + public class BufferedCommunicationChannel : IMqttCommunicationChannel + { + private IMqttCommunicationChannel _inner { get; } + private int _bufferSize = 0; + private int _bufferOffset = 0; + + public BufferedCommunicationChannel(IMqttCommunicationChannel inner) + { + _inner = inner; + } + + public Task ConnectAsync(MqttClientOptions options) + { + return _inner.ConnectAsync(options); + } + + public Task DisconnectAsync() + { + return _inner.DisconnectAsync(); + } + + public int Peek() + { + return _inner.Peek(); + } + + public async Task> ReadAsync(int length, byte[] buffer) + { + //read from buffer + if (_bufferSize > 0) + { + return ReadFomBuffer(length, buffer); + } + + var available = _inner.Peek(); + // if there are less or equal bytes available then requested then just read em + if (available <= length) + { + return await _inner.ReadAsync(length, buffer); + } + + //if more bytes are available than requested do buffer them to reduce calls to network buffers + await WriteToBuffer(available, buffer).ConfigureAwait(false); + return ReadFomBuffer(length, buffer); + } + + private async Task WriteToBuffer(int available, byte[] buffer) + { + await _inner.ReadAsync(available, buffer).ConfigureAwait(false); + _bufferSize = available; + _bufferOffset = 0; + } + + private ArraySegment ReadFomBuffer(int length, byte[] buffer) + { + var result = new ArraySegment(buffer, _bufferOffset, length); + _bufferSize -= length; + _bufferOffset += length; + + if (_bufferSize < 0) + { + } + return result; + } + + public Task WriteAsync(byte[] buffer) + { + return _inner.WriteAsync(buffer); + } + } +} diff --git a/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs b/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs index a1ec890..0f6ea4b 100644 --- a/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs +++ b/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs @@ -1,5 +1,6 @@ using System.Threading.Tasks; using MQTTnet.Core.Client; +using System; namespace MQTTnet.Core.Channel { @@ -11,6 +12,11 @@ namespace MQTTnet.Core.Channel Task WriteAsync(byte[] buffer); - Task ReadAsync(byte[] buffer); + /// + /// get the currently available number of bytes without reading them + /// + int Peek(); + + Task> ReadAsync(int length, byte[] buffer); } } diff --git a/MQTTnet.Core/Serializer/MqttPacketReader.cs b/MQTTnet.Core/Serializer/MqttPacketReader.cs index af27858..38d3c7c 100644 --- a/MQTTnet.Core/Serializer/MqttPacketReader.cs +++ b/MQTTnet.Core/Serializer/MqttPacketReader.cs @@ -49,13 +49,13 @@ namespace MQTTnet.Core.Serializer return ReadBytes(_header.BodyLength - (int)BaseStream.Position); } - public static async Task ReadHeaderFromSourceAsync(IMqttCommunicationChannel source) + public static async Task ReadHeaderFromSourceAsync(IMqttCommunicationChannel source, byte[] buffer) { - var fixedHeader = await ReadStreamByteAsync(source).ConfigureAwait(false); + var fixedHeader = await ReadStreamByteAsync(source, buffer).ConfigureAwait(false); var byteReader = new ByteReader(fixedHeader); byteReader.Read(4); var controlPacketType = (MqttControlPacketType)byteReader.Read(4); - var bodyLength = await ReadBodyLengthFromSourceAsync(source).ConfigureAwait(false); + var bodyLength = await ReadBodyLengthFromSourceAsync(source, buffer).ConfigureAwait(false); return new MqttPacketHeader() { @@ -65,18 +65,17 @@ namespace MQTTnet.Core.Serializer }; } - private static async Task ReadStreamByteAsync(IMqttCommunicationChannel source) + private static async Task ReadStreamByteAsync(IMqttCommunicationChannel source, byte[] readBuffer) { - var buffer = new byte[1]; - await ReadFromSourceAsync(source, buffer).ConfigureAwait(false); - return buffer[0]; + var result = await ReadFromSourceAsync(source, 1, readBuffer).ConfigureAwait(false); + return result.Array[result.Offset]; } - public static async Task ReadFromSourceAsync(IMqttCommunicationChannel source, byte[] buffer) + public static async Task> ReadFromSourceAsync(IMqttCommunicationChannel source, int length, byte[] buffer) { try { - await source.ReadAsync(buffer); + return await source.ReadAsync(length, buffer); } catch (Exception exception) { @@ -84,7 +83,7 @@ namespace MQTTnet.Core.Serializer } } - private static async Task ReadBodyLengthFromSourceAsync(IMqttCommunicationChannel source) + private static async Task ReadBodyLengthFromSourceAsync(IMqttCommunicationChannel source, byte[] buffer) { // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. var multiplier = 1; @@ -92,7 +91,7 @@ namespace MQTTnet.Core.Serializer byte encodedByte; do { - encodedByte = await ReadStreamByteAsync(source).ConfigureAwait(false); + encodedByte = await ReadStreamByteAsync(source, buffer).ConfigureAwait(false); value += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128 * 128 * 128) diff --git a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs index 88a7cfe..064b79b 100644 --- a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs @@ -17,6 +17,7 @@ namespace MQTTnet.Core.Serializer private static byte[] ProtocolVersionV310Name { get; } = Encoding.UTF8.GetBytes("MQIs"); public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; + private byte[] _readBuffer = new byte[BufferConstants.Size]; public async Task SerializeAsync(MqttBasePacket packet, IMqttCommunicationChannel destination) { @@ -115,17 +116,13 @@ namespace MQTTnet.Core.Serializer public async Task DeserializeAsync(IMqttCommunicationChannel source) { if (source == null) throw new ArgumentNullException(nameof(source)); - - var header = await MqttPacketReader.ReadHeaderFromSourceAsync(source).ConfigureAwait(false); - var body = new byte[header.BodyLength]; - if (header.BodyLength > 0) - { - await MqttPacketReader.ReadFromSourceAsync(source, body).ConfigureAwait(false); - } + var header = await MqttPacketReader.ReadHeaderFromSourceAsync(source, _readBuffer).ConfigureAwait(false); - using (var mqttPacketReader = new MqttPacketReader(new MemoryStream(body), header)) - { + var body = await GetBody(source, header).ConfigureAwait(false); + + using (var mqttPacketReader = new MqttPacketReader(body, header)) + { switch (header.ControlPacketType) { case MqttControlPacketType.Connect: @@ -221,6 +218,17 @@ namespace MQTTnet.Core.Serializer } } + private async Task GetBody(IMqttCommunicationChannel source, MqttPacketHeader header) + { + if (header.BodyLength > 0) + { + var segment = await MqttPacketReader.ReadFromSourceAsync(source, header.BodyLength, _readBuffer).ConfigureAwait(false); + return new MemoryStream(segment.Array, segment.Offset, segment.Count); + } + + return new MemoryStream(); + } + private static MqttBasePacket DeserializeUnsubscribe(MqttPacketReader reader) { var packet = new MqttUnsubscribePacket diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index fad1f4e..6a18f50 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -418,15 +418,21 @@ namespace MQTTnet.Core.Tests return _stream.WriteAsync(buffer, 0, buffer.Length); } - public Task ReadAsync(byte[] buffer) + public async Task> ReadAsync(int length, byte[] buffer) { - return _stream.ReadAsync(buffer, 0, buffer.Length); + await _stream.ReadAsync(buffer, 0, length); + return new ArraySegment(buffer, 0, length); } public byte[] ToArray() { return _stream.ToArray(); } + + public int Peek() + { + return (int)_stream.Length - (int)_stream.Position; + } } private static void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) diff --git a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs index 21e791b..32750d1 100644 --- a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs @@ -29,7 +29,8 @@ namespace MQTTnet.TestApp.NetFramework { Server = "localhost", ClientId = "XYZ", - CleanSession = true + CleanSession = true, + DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) }; var client = new MqttClientFactory().CreateMqttClient(options); @@ -128,7 +129,8 @@ namespace MQTTnet.TestApp.NetFramework } return MqttConnectReturnCode.ConnectionAccepted; - } + }, + DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) }; var mqttServer = new MqttServerFactory().CreateMqttServer(options); diff --git a/Tests/MQTTnet.TestApp.NetFramework/Program.cs b/Tests/MQTTnet.TestApp.NetFramework/Program.cs index d27f423..1111653 100644 --- a/Tests/MQTTnet.TestApp.NetFramework/Program.cs +++ b/Tests/MQTTnet.TestApp.NetFramework/Program.cs @@ -20,18 +20,18 @@ namespace MQTTnet.TestApp.NetFramework Console.WriteLine("1 = Start client"); Console.WriteLine("2 = Start server"); Console.WriteLine("3 = Start performance test"); - var pressedKey = Console.ReadKey(true); - if (pressedKey.Key == ConsoleKey.D1) - { - Task.Run(() => RunClientAsync(args)); - Thread.Sleep(Timeout.Infinite); - } - else if (pressedKey.Key == ConsoleKey.D2) - { - Task.Run(() => RunServerAsync(args)); - Thread.Sleep(Timeout.Infinite); - } - else if (pressedKey.Key == ConsoleKey.D3) + //var pressedKey = Console.ReadKey(true); + //if (pressedKey.Key == ConsoleKey.D1) + //{ + // Task.Run(() => RunClientAsync(args)); + // Thread.Sleep(Timeout.Infinite); + //} + //else if (pressedKey.Key == ConsoleKey.D2) + //{ + // Task.Run(() => RunServerAsync(args)); + // Thread.Sleep(Timeout.Infinite); + //} + //else if (pressedKey.Key == ConsoleKey.D3) { Task.Run(() => PerformanceTest.RunAsync()); Thread.Sleep(Timeout.Infinite);