diff --git a/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs b/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs index 7325a5f..82bd915 100644 --- a/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs +++ b/Frameworks/MQTTnet.AspnetCore/MqttWebSocketServerChannel.cs @@ -8,22 +8,24 @@ using MQTTnet.Implementations; namespace MQTTnet.AspNetCore { - public class MqttWebSocketServerChannel : IMqttChannel, IDisposable + public class MqttWebSocketServerChannel : IMqttChannel { private WebSocket _webSocket; + private readonly MqttWebSocketChannel _channel; public MqttWebSocketServerChannel(WebSocket webSocket) { _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); - SendStream = new WebSocketStream(_webSocket); + + _channel = new MqttWebSocketChannel(webSocket); ReceiveStream = SendStream; } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } + private Stream SendStream { get; set; } + private Stream ReceiveStream { get; set; } - public Task ConnectAsync() + public Task ConnectAsync(CancellationToken cancellationToken) { return Task.CompletedTask; } @@ -37,7 +39,7 @@ namespace MQTTnet.AspNetCore try { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); } finally { @@ -45,6 +47,16 @@ namespace MQTTnet.AspNetCore } } + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ReceiveStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return SendStream.WriteAsync(buffer, offset, count, cancellationToken); + } + public void Dispose() { SendStream?.Dispose(); diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs index 8f03e65..11879f2 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MQTTnet.Packets; @@ -11,7 +10,7 @@ namespace MQTTnet.Adapter { IMqttPacketSerializer PacketSerializer { get; } - Task ConnectAsync(TimeSpan timeout); + Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken); Task DisconnectAsync(TimeSpan timeout); diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index 087b104..1d48f23 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -32,12 +32,12 @@ namespace MQTTnet.Adapter public IMqttPacketSerializer PacketSerializer { get; } - public Task ConnectAsync(TimeSpan timeout) + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { ThrowIfDisposed(); _logger.Verbose("Connecting [Timeout={0}]", timeout); - return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync().TimeoutAfter(timeout)); + return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout)); } public Task DisconnectAsync(TimeSpan timeout) @@ -50,44 +50,32 @@ namespace MQTTnet.Adapter public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) { + ThrowIfDisposed(); + foreach (var packet in packets) { - await SendPacketsAsync(timeout, cancellationToken, packet).ConfigureAwait(false); + if (packet == null) + { + continue; + } + + await SendPacketAsync(timeout, cancellationToken, packet).ConfigureAwait(false); } } - private Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) + private Task SendPacketAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) { - ThrowIfDisposed(); - - if (packet == null) + return ExecuteAndWrapExceptionAsync(() => { - return Task.FromResult(0); - } - - return ExecuteAndWrapExceptionAsync(async () => - { - if (cancellationToken.IsCancellationRequested) - { - return; - } - _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); var packetData = PacketSerializer.Serialize(packet); - if (cancellationToken.IsCancellationRequested) - { - return; - } - - await _channel.SendStream.WriteAsync( + return _channel.WriteAsync( packetData.Array, packetData.Offset, packetData.Count, - cancellationToken).ConfigureAwait(false); - - await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); + cancellationToken); }); } @@ -101,7 +89,6 @@ namespace MQTTnet.Adapter ReceivedMqttPacket receivedMqttPacket = null; try { - if (timeout > TimeSpan.Zero) { var timeoutCts = new CancellationTokenSource(timeout); @@ -109,14 +96,14 @@ namespace MQTTnet.Adapter try { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, linkedCts.Token).ConfigureAwait(false); + receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false); } - catch (OperationCanceledException ex) + catch (OperationCanceledException exception) { var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; if (timedOut) { - throw new MqttCommunicationTimedOutException(ex); + throw new MqttCommunicationTimedOutException(exception); } else { @@ -126,7 +113,7 @@ namespace MQTTnet.Adapter } else { - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); + receivedMqttPacket = await ReceiveAsync(_channel, cancellationToken).ConfigureAwait(false); } if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) @@ -151,9 +138,9 @@ namespace MQTTnet.Adapter return packet; } - private static async Task ReceiveAsync(Stream stream, CancellationToken cancellationToken) + private static async Task ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) { - var header = await MqttPacketReader.ReadHeaderAsync(stream, cancellationToken).ConfigureAwait(false); + var header = await MqttPacketReader.ReadHeaderAsync(channel, cancellationToken).ConfigureAwait(false); if (header == null) { return null; @@ -166,7 +153,7 @@ namespace MQTTnet.Adapter var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(header.BodyLength) : new MemoryStream(); - var buffer = new byte[ReadBufferSize]; + var buffer = new byte[Math.Min(ReadBufferSize, header.BodyLength)]; while (body.Length < header.BodyLength) { var bytesLeft = header.BodyLength - (int)body.Length; @@ -175,7 +162,7 @@ namespace MQTTnet.Adapter bytesLeft = buffer.Length; } - var readBytesCount = await stream.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); + var readBytesCount = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); // Check if the client closed the connection before sending the full body. if (readBytesCount == 0) diff --git a/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs b/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs index b332884..dd9d1f2 100644 --- a/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Channel/IMqttChannel.cs @@ -1,15 +1,15 @@ using System; -using System.IO; +using System.Threading; using System.Threading.Tasks; namespace MQTTnet.Channel { public interface IMqttChannel : IDisposable { - Stream SendStream { get; } - Stream ReceiveStream { get; } - - Task ConnectAsync(); + Task ConnectAsync(CancellationToken cancellationToken); Task DisconnectAsync(); + + Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); + Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); } } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs index 90790d5..ee2189e 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs @@ -60,7 +60,7 @@ namespace MQTTnet.Client _adapter = _adapterFactory.CreateClientAdapter(options, _logger); _logger.Verbose("Trying to connect with server."); - await _adapter.ConnectAsync(_options.CommunicationTimeout).ConfigureAwait(false); + await _adapter.ConnectAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false); _logger.Verbose("Connection with server established."); await StartReceivingPacketsAsync().ConfigureAwait(false); @@ -92,14 +92,9 @@ namespace MQTTnet.Client public async Task DisconnectAsync() { - if (!IsConnected) - { - return; - } - try { - if (!_cancellationTokenSource.IsCancellationRequested) + if (IsConnected && !_cancellationTokenSource.IsCancellationRequested) { await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); } diff --git a/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs b/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs index ac96ce6..44bc5b4 100644 --- a/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs +++ b/Frameworks/MQTTnet.NetStandard/Client/MqttClientTcpOptions.cs @@ -6,7 +6,7 @@ public int? Port { get; set; } - public int BufferSize { get; set; } = 20 * 4096; + public int BufferSize { get; set; } = 4096; public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions(); } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs index a9cd18e..1487f0a 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.Uwp.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.InteropServices.WindowsRuntime; +using System.Threading; using System.Threading.Tasks; using Windows.Networking; using Windows.Networking.Sockets; @@ -17,36 +18,35 @@ namespace MQTTnet.Implementations { // ReSharper disable once MemberCanBePrivate.Global // ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global - public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. + public static int BufferSize { get; set; } = 4096; // Can be changed for fine tuning by library user. private readonly int _bufferSize = BufferSize; private readonly MqttClientTcpOptions _options; private StreamSocket _socket; + private Stream _readStream; + private Stream _writeStream; public MqttTcpChannel(MqttClientTcpOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); - _bufferSize = _options.BufferSize; + + _bufferSize = options.BufferSize; } public MqttTcpChannel(StreamSocket socket) { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); - - CreateStreams(); } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } - public static Func> CustomIgnorableServerCertificateErrorsResolver { get; set; } - public async Task ConnectAsync() + public async Task ConnectAsync(CancellationToken cancellationToken) { if (_socket == null) { _socket = new StreamSocket(); + _socket.Control.NoDelay = true; } if (!_options.TlsOptions.UseTls) @@ -65,7 +65,8 @@ namespace MQTTnet.Implementations await _socket.ConnectAsync(new HostName(_options.Server), _options.GetPort().ToString(), SocketProtectionLevel.Tls12); } - CreateStreams(); + _readStream = _socket.InputStream.AsStreamForRead(_bufferSize); + _writeStream = _socket.OutputStream.AsStreamForWrite(_bufferSize); } public Task DisconnectAsync() @@ -74,11 +75,22 @@ namespace MQTTnet.Implementations return Task.FromResult(0); } + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _readStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _writeStream.WriteAsync(buffer, offset, count, cancellationToken); + await _writeStream.FlushAsync(cancellationToken); + } + public void Dispose() { try { - SendStream?.Dispose(); + _readStream?.Dispose(); } catch (ObjectDisposedException) { @@ -88,12 +100,12 @@ namespace MQTTnet.Implementations } finally { - SendStream = null; + _readStream = null; } try { - ReceiveStream?.Dispose(); + _writeStream?.Dispose(); } catch (ObjectDisposedException) { @@ -103,7 +115,7 @@ namespace MQTTnet.Implementations } finally { - ReceiveStream = null; + _writeStream = null; } try @@ -122,12 +134,6 @@ namespace MQTTnet.Implementations } } - private void CreateStreams() - { - SendStream = _socket.OutputStream.AsStreamForWrite(_bufferSize); - ReceiveStream = _socket.InputStream.AsStreamForRead(_bufferSize); - } - private static Certificate LoadCertificate(MqttClientTcpOptions options) { if (options.TlsOptions.Certificates == null || !options.TlsOptions.Certificates.Any()) diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index 0f9bba4..fe01781 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -7,6 +7,7 @@ using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using System.IO; using System.Linq; +using System.Threading; using MQTTnet.Channel; using MQTTnet.Client; @@ -14,20 +15,10 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttChannel { -#if NET452 || NET461 || NETSTANDARD2_0 - // ReSharper disable once MemberCanBePrivate.Global - // ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global - public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. - - private readonly int _bufferSize = BufferSize; -#else - private readonly int _bufferSize = 0; -#endif - private readonly MqttClientTcpOptions _options; private Socket _socket; - private SslStream _sslStream; + private Stream _stream; /// /// called on client sockets are created in connect @@ -35,7 +26,6 @@ namespace MQTTnet.Implementations public MqttTcpChannel(MqttClientTcpOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); - _bufferSize = options.BufferSize; } /// @@ -45,21 +35,17 @@ namespace MQTTnet.Implementations public MqttTcpChannel(Socket socket, SslStream sslStream) { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); - _sslStream = sslStream; - CreateStreams(); + CreateStream(sslStream); } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } - public static Func CustomCertificateValidationCallback { get; set; } - public async Task ConnectAsync() + public async Task ConnectAsync(CancellationToken cancellationToken) { if (_socket == null) { - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; } #if NET452 || NET461 @@ -68,15 +54,14 @@ namespace MQTTnet.Implementations await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); #endif - _socket.NoDelay = true; - + SslStream sslStream = null; if (_options.TlsOptions.UseTls) { - _sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); - await _sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); + sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); + await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); } - - CreateStreams(); + + CreateStream(sslStream); } public Task DisconnectAsync() @@ -85,46 +70,21 @@ namespace MQTTnet.Implementations return Task.FromResult(0); } - public void Dispose() + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var oneStreamIsUsed = SendStream != null && ReceiveStream != null && ReferenceEquals(SendStream, ReceiveStream); - - try - { - SendStream?.Dispose(); - } - catch (ObjectDisposedException) - { - } - catch (NullReferenceException) - { - } - finally - { - SendStream = null; - } + return _stream.ReadAsync(buffer, offset, count, cancellationToken); + } - try - { - if (!oneStreamIsUsed) - { - ReceiveStream?.Dispose(); - } - } - catch (ObjectDisposedException) - { - } - catch (NullReferenceException) - { - } - finally - { - ReceiveStream = null; - } + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _stream.WriteAsync(buffer, offset, count, cancellationToken); + } + public void Dispose() + { try { - _sslStream?.Dispose(); + _stream?.Dispose(); } catch (ObjectDisposedException) { @@ -134,7 +94,7 @@ namespace MQTTnet.Implementations } finally { - _sslStream = null; + _stream = null; } try @@ -200,25 +160,16 @@ namespace MQTTnet.Implementations return certificates; } - private void CreateStreams() + private void CreateStream(Stream stream) { - Stream stream; - if (_sslStream != null) + if (stream != null) { - stream = _sslStream; + _stream = stream; } else { - stream = new NetworkStream(_socket, true); + _stream = new NetworkStream(_socket, true); } - -#if NET452 || NET461 || NETSTANDARD2_0 - SendStream = new BufferedStream(stream, _bufferSize); - ReceiveStream = new BufferedStream(stream, _bufferSize); -#else - SendStream = stream; - ReceiveStream = stream; -#endif } } } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs index e71d6ff..0d84ab4 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.Uwp.cs @@ -62,6 +62,8 @@ namespace MQTTnet.Implementations { try { + args.Socket.Control.NoDelay = true; + var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(args.Socket), new MqttPacketSerializer(), _logger); ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs index b718dbb..c9f38ab 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs @@ -43,7 +43,7 @@ namespace MQTTnet.Implementations _defaultEndpointSocket.Bind(new IPEndPoint(options.DefaultEndpointOptions.BoundIPAddress, options.GetDefaultEndpointPort())); _defaultEndpointSocket.Listen(options.ConnectionBacklog); - Task.Run(async () => await AcceptDefaultEndpointConnectionsAsync(_cancellationTokenSource.Token).ConfigureAwait(false), _cancellationTokenSource.Token).ConfigureAwait(false); + Task.Run(() => AcceptDefaultEndpointConnectionsAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); } if (options.TlsEndpointOptions.IsEnabled) @@ -63,7 +63,7 @@ namespace MQTTnet.Implementations _tlsEndpointSocket.Bind(new IPEndPoint(options.TlsEndpointOptions.BoundIPAddress, options.GetTlsEndpointPort())); _tlsEndpointSocket.Listen(options.ConnectionBacklog); - Task.Run(async () => await AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token).ConfigureAwait(false), _cancellationTokenSource.Token).ConfigureAwait(false); + Task.Run(() => AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); } return Task.FromResult(0); diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs index 057a07e..4a9fe65 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs @@ -1,11 +1,13 @@ using System; -using System.IO; +using System.Collections.Generic; +using System.Linq; using System.Net.WebSockets; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; using MQTTnet.Client; +using MQTTnet.Exceptions; namespace MQTTnet.Implementations { @@ -13,20 +15,25 @@ namespace MQTTnet.Implementations { // ReSharper disable once MemberCanBePrivate.Global // ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global - public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. + public static int BufferSize { get; set; } = 4096; // Can be changed for fine tuning by library user. + private readonly byte[] _chunckBuffer = new byte[BufferSize]; + private readonly Queue _buffer = new Queue(BufferSize); private readonly MqttClientWebSocketOptions _options; - private ClientWebSocket _webSocket; + + private WebSocket _webSocket; public MqttWebSocketChannel(MqttClientWebSocketOptions options) { _options = options ?? throw new ArgumentNullException(nameof(options)); } - public Stream SendStream { get; private set; } - public Stream ReceiveStream { get; private set; } + public MqttWebSocketChannel(WebSocket webSocket) + { + _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); + } - public async Task ConnectAsync() + public async Task ConnectAsync(CancellationToken cancellationToken) { var uri = _options.Uri; if (!uri.StartsWith("ws://", StringComparison.OrdinalIgnoreCase) && !uri.StartsWith("wss://", StringComparison.OrdinalIgnoreCase)) @@ -41,13 +48,13 @@ namespace MQTTnet.Implementations } } - _webSocket = new ClientWebSocket(); + var clientWebSocket = new ClientWebSocket(); if (_options.RequestHeaders != null) { foreach (var requestHeader in _options.RequestHeaders) { - _webSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); + clientWebSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); } } @@ -55,28 +62,26 @@ namespace MQTTnet.Implementations { foreach (var subProtocol in _options.SubProtocols) { - _webSocket.Options.AddSubProtocol(subProtocol); + clientWebSocket.Options.AddSubProtocol(subProtocol); } } if (_options.CookieContainer != null) { - _webSocket.Options.Cookies = _options.CookieContainer; + clientWebSocket.Options.Cookies = _options.CookieContainer; } if (_options.TlsOptions?.UseTls == true && _options.TlsOptions?.Certificates != null) { - _webSocket.Options.ClientCertificates = new X509CertificateCollection(); + clientWebSocket.Options.ClientCertificates = new X509CertificateCollection(); foreach (var certificate in _options.TlsOptions.Certificates) { - _webSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); + clientWebSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); } } - await _webSocket.ConnectAsync(new Uri(uri), CancellationToken.None).ConfigureAwait(false); - - SendStream = new WebSocketStream(_webSocket); - ReceiveStream = SendStream; + await clientWebSocket.ConnectAsync(new Uri(uri), cancellationToken).ConfigureAwait(false); + _webSocket = clientWebSocket; } public async Task DisconnectAsync() @@ -94,6 +99,56 @@ namespace MQTTnet.Implementations Dispose(); } + public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var bytesRead = 0; + + // Use existing date from buffer. + while (count > 0 && _buffer.Any()) + { + buffer[offset] = _buffer.Dequeue(); + count--; + bytesRead++; + offset++; + } + + if (count == 0) + { + return bytesRead; + } + + // Fetch new data if the buffer is not full. + while (_webSocket.State == WebSocketState.Open) + { + await FetchChunkAsync(cancellationToken).ConfigureAwait(false); + + while (count > 0 && _buffer.Any()) + { + buffer[offset] = _buffer.Dequeue(); + count--; + bytesRead++; + offset++; + } + + if (count == 0) + { + return bytesRead; + } + } + + if (_webSocket.State == WebSocketState.Closed) + { + throw new MqttCommunicationException("WebSocket connection closed."); + } + + return bytesRead; + } + + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); + } + public void Dispose() { try @@ -108,5 +163,26 @@ namespace MQTTnet.Implementations _webSocket = null; } } + + private async Task FetchChunkAsync(CancellationToken cancellationToken) + { + var response = await _webSocket.ReceiveAsync(new ArraySegment(_chunckBuffer, 0, _chunckBuffer.Length), cancellationToken).ConfigureAwait(false); + + if (response.MessageType == WebSocketMessageType.Close) + { + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); + } + else if (response.MessageType == WebSocketMessageType.Binary) + { + for (var i = 0; i < response.Count; i++) + { + _buffer.Enqueue(_chunckBuffer[i]); + } + } + else if (response.MessageType == WebSocketMessageType.Text) + { + throw new MqttProtocolViolationException("WebSocket channel received TEXT message."); + } + } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs deleted file mode 100644 index bad0393..0000000 --- a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs +++ /dev/null @@ -1,137 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net.WebSockets; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Exceptions; - -namespace MQTTnet.Implementations -{ - public class WebSocketStream : Stream - { - private readonly byte[] _chunckBuffer = new byte[MqttWebSocketChannel.BufferSize]; - private readonly Queue _buffer = new Queue(MqttWebSocketChannel.BufferSize); - private readonly WebSocket _webSocket; - - public WebSocketStream(WebSocket webSocket) - { - _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); - } - - public override bool CanRead => true; - - public override bool CanSeek => false; - - public override bool CanWrite => true; - - public override long Length => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - public override void Flush() - { - } - - public override Task FlushAsync(CancellationToken cancellationToken) - { - return Task.FromResult(0); - } - - public override int Read(byte[] buffer, int offset, int count) - { - return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - var bytesRead = 0; - - // Use existing date from buffer. - while (count > 0 && _buffer.Any()) - { - buffer[offset] = _buffer.Dequeue(); - count--; - bytesRead++; - offset++; - } - - if (count == 0) - { - return bytesRead; - } - - // Fetch new data if the buffer is not full. - while (_webSocket.State == WebSocketState.Open) - { - await FetchChunkAsync(cancellationToken).ConfigureAwait(false); - - while (count > 0 && _buffer.Any()) - { - buffer[offset] = _buffer.Dequeue(); - count--; - bytesRead++; - offset++; - } - - if (count == 0) - { - return bytesRead; - } - } - - if (_webSocket.State == WebSocketState.Closed) - { - throw new MqttCommunicationException("WebSocket connection closed."); - } - - return bytesRead; - } - - public override void Write(byte[] buffer, int offset, int count) - { - WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } - - public override void SetLength(long value) - { - throw new NotSupportedException(); - } - - private async Task FetchChunkAsync(CancellationToken cancellationToken) - { - var response = await _webSocket.ReceiveAsync(new ArraySegment(_chunckBuffer, 0, _chunckBuffer.Length), cancellationToken).ConfigureAwait(false); - - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); - } - else if (response.MessageType == WebSocketMessageType.Binary) - { - for (var i = 0; i < response.Count; i++) - { - _buffer.Enqueue(_chunckBuffer[i]); - } - } - else if (response.MessageType == WebSocketMessageType.Text) - { - throw new MqttProtocolViolationException("WebSocket channel received TEXT message."); - } - } - } -} diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs index 6afdd98..3688330 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.IO; using MQTTnet.Packets; @@ -11,6 +10,6 @@ namespace MQTTnet.Serializer ArraySegment Serialize(MqttBasePacket mqttPacket); - MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); + MqttBasePacket Deserialize(MqttPacketHeader header, Stream body); } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs index bd43f7d..4c43579 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs @@ -1,9 +1,9 @@ using System; -using System.Collections.Generic; using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; +using MQTTnet.Channel; using MQTTnet.Exceptions; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -13,7 +13,7 @@ namespace MQTTnet.Serializer public sealed class MqttPacketReader : BinaryReader { private readonly MqttPacketHeader _header; - + public MqttPacketReader(MqttPacketHeader header, Stream bodyStream) : base(bodyStream, Encoding.UTF8, true) { @@ -22,7 +22,7 @@ namespace MQTTnet.Serializer public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; - public static async Task ReadHeaderAsync(Stream stream, CancellationToken cancellationToken) + public static async Task ReadHeaderAsync(IMqttChannel stream, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { @@ -33,7 +33,7 @@ namespace MQTTnet.Serializer // some large delay and thus the thread should be put back to the pool (await). So ReadByte() // is not an option here. var buffer = new byte[1]; - var readCount = await stream.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); + var readCount = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); if (readCount <= 0) { return null; @@ -89,15 +89,14 @@ namespace MQTTnet.Serializer return ReadBytes(_header.BodyLength - (int)BaseStream.Position); } - private static async Task ReadBodyLengthAsync(Stream stream, CancellationToken cancellationToken) + private static async Task ReadBodyLengthAsync(IMqttChannel stream, CancellationToken cancellationToken) { // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. var multiplier = 1; var value = 0; - byte encodedByte; - + int encodedByte; var buffer = new byte[1]; - var readBytes = new List(); + do { if (cancellationToken.IsCancellationRequested) @@ -112,12 +111,11 @@ namespace MQTTnet.Serializer } encodedByte = buffer[0]; - readBytes.Add(encodedByte); value += (byte)(encodedByte & 127) * multiplier; if (multiplier > 128 * 128 * 128) { - throw new MqttProtocolViolationException($"Remaining length is invalid (Data={string.Join(",", readBytes)})."); + throw new MqttProtocolViolationException("Remaining length is invalid."); } multiplier *= 128; diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 2447e05..8225e60 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -27,7 +27,7 @@ namespace MQTTnet.Serializer var fixedHeader = SerializePacket(packet, writer); - var remainingLength = MqttPacketWriter.GetRemainingLength((int)stream.Length - 5); + var remainingLength = MqttPacketWriter.EncodeRemainingLength((int)stream.Length - 5); var headerSize = remainingLength.Length + 1; var headerOffset = 5 - headerSize; @@ -47,7 +47,7 @@ namespace MQTTnet.Serializer } } - public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) + public MqttBasePacket Deserialize(MqttPacketHeader header, Stream body) { if (header == null) throw new ArgumentNullException(nameof(header)); if (body == null) throw new ArgumentNullException(nameof(body)); @@ -183,7 +183,7 @@ namespace MQTTnet.Serializer var topic = reader.ReadStringWithLengthPrefix(); - ushort packetIdentifier = 0; + ushort? packetIdentifier = null; if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { packetIdentifier = reader.ReadUInt16(); @@ -191,12 +191,12 @@ namespace MQTTnet.Serializer var packet = new MqttPublishPacket { + PacketIdentifier = packetIdentifier, Retain = retain, - QualityOfServiceLevel = qualityOfServiceLevel, - Dup = dup, Topic = topic, Payload = reader.ReadRemainingData(), - PacketIdentifier = packetIdentifier + QualityOfServiceLevel = qualityOfServiceLevel, + Dup = dup }; return packet; diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs index 253294b..3de4961 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs @@ -56,7 +56,7 @@ namespace MQTTnet.Serializer Write(value); } - public static byte[] GetRemainingLength(int length) + public static byte[] EncodeRemainingLength(int length) { if (length <= 0) { @@ -82,7 +82,8 @@ namespace MQTTnet.Serializer offset++; } while (x > 0); - return bytes.Take(offset).ToArray(); + Array.Resize(ref bytes, offset); + return bytes; } } } diff --git a/README.md b/README.md index 83a03a3..313fc72 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov * TLS 1.2 support for client and server (but not UWP servers) * Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS) * Lightweight (only the low level implementation of MQTT, no overhead) -* Performance optimized (processing ~40.000 messages / second)* +* Performance optimized (processing ~50.000 messages / second)* * Interfaces included for mocking and testing * Access to internal trace messages * Unit tested (~80 tests) diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index 4d28035..ca48ad9 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -11,8 +11,7 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttPacketReader_EmptyStream() { - var memStream = new MemoryStream(); - var header = MqttPacketReader.ReadHeaderAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); + var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); Assert.IsNull(header); } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index a7421c6..276475a 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -416,7 +416,7 @@ namespace MQTTnet.Core.Tests using (var headerStream = new MemoryStream(Join(buffer1))) { - var header = MqttPacketReader.ReadHeaderAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); + var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(headerStream), CancellationToken.None).GetAwaiter().GetResult(); using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength)) { diff --git a/Tests/MQTTnet.Core.Tests/TestMqttChannel.cs b/Tests/MQTTnet.Core.Tests/TestMqttChannel.cs new file mode 100644 index 0000000..2b4914b --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/TestMqttChannel.cs @@ -0,0 +1,41 @@ +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Channel; + +namespace MQTTnet.Core.Tests +{ + public class TestMqttChannel : IMqttChannel + { + private readonly MemoryStream _stream; + + public TestMqttChannel(MemoryStream stream) + { + _stream = stream; + } + + public void Dispose() + { + } + + public Task ConnectAsync(CancellationToken cancellationToken) + { + return Task.FromResult(0); + } + + public Task DisconnectAsync() + { + return Task.FromResult(0); + } + + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _stream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _stream.WriteAsync(buffer, offset, count, cancellationToken); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index f564e51..6c1bec5 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -21,7 +20,7 @@ namespace MQTTnet.Core.Tests { } - public Task ConnectAsync(TimeSpan timeout) + public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) { return Task.FromResult(0); } diff --git a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs index b485533..878a788 100644 --- a/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/PerformanceTest.cs @@ -59,7 +59,7 @@ namespace MQTTnet.TestApp.NetCore var sentMessagesCount = 0; while (stopwatch.ElapsedMilliseconds < 1000) { - await client.PublishAsync(messages).ConfigureAwait(false); + client.PublishAsync(messages).GetAwaiter().GetResult(); sentMessagesCount++; }