diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 78c44f9..84e44b6 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -11,7 +11,7 @@ false MQTTnet is a .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker). * [Client] Added support for web socket communication channel (thanks to nowakpiotr) -* [Core] Performance optimizations (thanks to JanEggers) +* [Core] Huge performance optimizations (thanks to JanEggers) Copyright Christian Kratky 2016-2017 MQTT MQTTClient MQTTServer MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Queue Hardware Arduino diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs index 47dfb45..be6f778 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttServerAdapter.cs @@ -10,7 +10,6 @@ using MQTTnet.Core.Adapter; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Serializer; using MQTTnet.Core.Server; -using MQTTnet.Core.Channel; namespace MQTTnet.Implementations { @@ -87,7 +86,9 @@ namespace MQTTnet.Implementations try { var clientSocket = await Task.Factory.FromAsync(_defaultEndpointSocket.BeginAccept, _defaultEndpointSocket.EndAccept, null).ConfigureAwait(false); - var clientAdapter = new MqttChannelCommunicationAdapter(new BufferedCommunicationChannel(new MqttTcpChannel(clientSocket, null)), new MqttPacketSerializer()); + + var tcpChannel = new MqttTcpChannel(clientSocket, null); + var clientAdapter = new MqttChannelCommunicationAdapter(tcpChannel, new MqttPacketSerializer()); ClientConnected?.Invoke(this, new MqttClientConnectedEventArgs(clientSocket.RemoteEndPoint.ToString(), clientAdapter)); } catch (Exception exception) when (!(exception is ObjectDisposedException)) @@ -110,8 +111,9 @@ namespace MQTTnet.Implementations var sslStream = new SslStream(new NetworkStream(clientSocket)); await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false); - - var clientAdapter = new MqttChannelCommunicationAdapter(new MqttTcpChannel(clientSocket, sslStream), new MqttPacketSerializer()); + + var tcpChannel = new MqttTcpChannel(clientSocket, sslStream); + var clientAdapter = new MqttChannelCommunicationAdapter(tcpChannel, new MqttPacketSerializer()); ClientConnected?.Invoke(this, new MqttClientConnectedEventArgs(clientSocket.RemoteEndPoint.ToString(), clientAdapter)); } catch (Exception exception) diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs index 0ecf373..33ba567 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs @@ -13,16 +13,19 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { - private Stream _dataStream; private Socket _socket; private SslStream _sslStream; + public Stream RawStream { get; private set; } + public Stream SendStream { get; private set; } + public Stream ReceiveStream { get; private set; } + /// /// called on client sockets are created in connect /// public MqttTcpChannel() { - + } /// @@ -33,7 +36,7 @@ namespace MQTTnet.Implementations { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); _sslStream = sslStream; - _dataStream = (Stream)sslStream ?? new NetworkStream(socket); + CreateCommStreams(socket, sslStream); } public async Task ConnectAsync(MqttClientOptions options) @@ -51,14 +54,11 @@ namespace MQTTnet.Implementations if (options.TlsOptions.UseTls) { _sslStream = new SslStream(new NetworkStream(_socket, true)); - - _dataStream = _sslStream; + await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); } - else - { - _dataStream = new NetworkStream(_socket); - } + + CreateCommStreams(_socket, _sslStream); } catch (SocketException exception) { @@ -79,45 +79,6 @@ namespace MQTTnet.Implementations } } - public async Task WriteAsync(byte[] buffer) - { - if (buffer == null) throw new ArgumentNullException(nameof(buffer)); - - try - { - await _dataStream.WriteAsync(buffer, 0, buffer.Length); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - - public async Task> ReadAsync(int length, byte[] buffer) - { - try - { - var totalBytes = 0; - - do - { - var read = await _dataStream.ReadAsync(buffer, totalBytes, length - totalBytes).ConfigureAwait(false); - if (read == 0) - { - throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); - } - - totalBytes += read; - } - while (totalBytes < length); - return new ArraySegment(buffer, 0, length); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - public void Dispose() { _socket?.Dispose(); @@ -127,6 +88,16 @@ namespace MQTTnet.Implementations _sslStream = null; } + private void CreateCommStreams(Socket socket, SslStream sslStream) + { + RawStream = (Stream)sslStream ?? new NetworkStream(socket); + + //cannot use this as default buffering prevents from receiving the first connect message + //need two streams otherwise read and write have to be synchronized + SendStream = new BufferedStream(RawStream, BufferConstants.Size); + ReceiveStream = new BufferedStream(RawStream, BufferConstants.Size); + } + private static X509CertificateCollection LoadCertificates(MqttClientOptions options) { var certificates = new X509CertificateCollection(); @@ -142,10 +113,5 @@ namespace MQTTnet.Implementations return certificates; } - - public int Peek() - { - return _socket.Available; - } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs index 925c818..bc224cb 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs @@ -2,6 +2,7 @@ using MQTTnet.Core.Client; using MQTTnet.Core.Exceptions; using System; +using System.IO; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; @@ -11,8 +12,11 @@ namespace MQTTnet.Implementations public sealed class MqttWebSocketChannel : IMqttCommunicationChannel, IDisposable { private ClientWebSocket _webSocket = new ClientWebSocket(); - private int WebSocketBufferSize; - private int WebSocketBufferOffset; + + public Stream RawStream { get; private set; } + + public Stream SendStream => RawStream; + public Stream ReceiveStream => RawStream; public async Task ConnectAsync(MqttClientOptions options) { @@ -22,6 +26,8 @@ namespace MQTTnet.Implementations { _webSocket = new ClientWebSocket(); await _webSocket.ConnectAsync(new Uri(options.Server), CancellationToken.None); + + RawStream = new WebSocketStream(_webSocket); } catch (WebSocketException exception) { @@ -31,6 +37,7 @@ namespace MQTTnet.Implementations public Task DisconnectAsync() { + RawStream = null; return _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); } @@ -38,62 +45,5 @@ namespace MQTTnet.Implementations { _webSocket?.Dispose(); } - - public async Task> ReadAsync(int length, byte[] buffer) - { - await ReadToBufferAsync(length, buffer).ConfigureAwait(false); - - var result = new ArraySegment(buffer, WebSocketBufferOffset, length); - WebSocketBufferSize -= length; - WebSocketBufferOffset += length; - - return result; - } - - private async Task ReadToBufferAsync(int length, byte[] buffer) - { - if (WebSocketBufferSize > 0) - { - return; - } - - var offset = 0; - while (_webSocket.State == WebSocketState.Open && WebSocketBufferSize < length) - { - WebSocketReceiveResult response; - do - { - response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, offset, buffer.Length - offset), CancellationToken.None).ConfigureAwait(false); - offset += response.Count; - } while (!response.EndOfMessage); - - WebSocketBufferSize = response.Count; - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); - } - } - } - - public async Task WriteAsync(byte[] buffer) - { - if (buffer == null) { - throw new ArgumentNullException(nameof(buffer)); - } - - try - { - await _webSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Binary, true, CancellationToken.None); - } - catch (WebSocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - - public int Peek() - { - return WebSocketBufferSize; - } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs new file mode 100644 index 0000000..d912148 --- /dev/null +++ b/Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs @@ -0,0 +1,80 @@ +using System; +using System.IO; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Implementations +{ + public class WebSocketStream : Stream + { + private readonly ClientWebSocket _webSocket; + + public WebSocketStream(ClientWebSocket webSocket) + { + _webSocket = webSocket; + } + + public override void Flush() + { + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var currentOffset = offset; + var targetOffset = offset + count; + while (_webSocket.State == WebSocketState.Open && currentOffset < targetOffset) + { + var response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, currentOffset, count), cancellationToken).ConfigureAwait(false); + currentOffset += response.Count; + + if (response.MessageType == WebSocketMessageType.Close) + { + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); + } + } + + return currentOffset - offset; + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + } +} diff --git a/Frameworks/MQTTnet.NetFramework/MQTTnet.NetFramework.csproj b/Frameworks/MQTTnet.NetFramework/MQTTnet.NetFramework.csproj index 7578afe..3f2519f 100644 --- a/Frameworks/MQTTnet.NetFramework/MQTTnet.NetFramework.csproj +++ b/Frameworks/MQTTnet.NetFramework/MQTTnet.NetFramework.csproj @@ -101,6 +101,7 @@ + diff --git a/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs b/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs index b3e0080..8075f92 100644 --- a/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs +++ b/Frameworks/MQTTnet.NetFramework/MqttClientFactory.cs @@ -22,7 +22,7 @@ namespace MQTTnet { case MqttConnectionType.Tcp: case MqttConnectionType.Tls: - return new BufferedCommunicationChannel( new MqttTcpChannel() ); + return new MqttTcpChannel(); case MqttConnectionType.Ws: case MqttConnectionType.Wss: return new MqttWebSocketChannel(); diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs index 6b02648..112d737 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttServerAdapter.cs @@ -27,6 +27,7 @@ namespace MQTTnet.Implementations public void Start(MqttServerOptions options) { if (_isRunning) throw new InvalidOperationException("Server is already started."); + _isRunning = true; _cancellationTokenSource = new CancellationTokenSource(); @@ -107,7 +108,7 @@ namespace MQTTnet.Implementations var sslStream = new SslStream(new NetworkStream(clientSocket)); await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, SslProtocols.Tls12, false).ConfigureAwait(false); - + var clientAdapter = new MqttChannelCommunicationAdapter(new MqttTcpChannel(clientSocket, sslStream), new MqttPacketSerializer()); ClientConnected?.Invoke(this, new MqttClientConnectedEventArgs(clientSocket.RemoteEndPoint.ToString(), clientAdapter)); } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index bddd79d..0ef5dad 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -13,9 +13,12 @@ namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { - private Stream _dataStream; private Socket _socket; private SslStream _sslStream; + + public Stream ReceiveStream { get; private set; } + public Stream RawStream => ReceiveStream; + public Stream SendStream => ReceiveStream; /// /// called on client sockets are created in connect @@ -32,7 +35,7 @@ namespace MQTTnet.Implementations { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); _sslStream = sslStream; - _dataStream = (Stream)sslStream ?? new NetworkStream(socket); + ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket); } public async Task ConnectAsync(MqttClientOptions options) @@ -51,12 +54,12 @@ namespace MQTTnet.Implementations if (options.TlsOptions.UseTls) { _sslStream = new SslStream(new NetworkStream(_socket, true)); - _dataStream = _sslStream; + ReceiveStream = _sslStream; await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); } else { - _dataStream = new NetworkStream(_socket); + ReceiveStream = new NetworkStream(_socket); } } catch (SocketException exception) @@ -78,45 +81,6 @@ namespace MQTTnet.Implementations } } - public async Task WriteAsync(byte[] buffer) - { - if (buffer == null) throw new ArgumentNullException(nameof(buffer)); - - try - { - await _dataStream.WriteAsync(buffer, 0, buffer.Length); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - - public async Task> ReadAsync(int length, byte[] buffer) - { - try - { - var totalBytes = 0; - - do - { - var read = await _dataStream.ReadAsync(buffer, totalBytes, length - totalBytes).ConfigureAwait(false); - if (read == 0) - { - throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); - } - - totalBytes += read; - } - while (totalBytes < length); - return new ArraySegment(buffer, 0, length); - } - catch (SocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - public void Dispose() { _socket?.Dispose(); @@ -141,10 +105,5 @@ namespace MQTTnet.Implementations return certificates; } - - public int Peek() - { - return _socket.Available; - } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs index 5a89ac7..e452cda 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttWebSocketChannel.cs @@ -2,6 +2,7 @@ using MQTTnet.Core.Client; using MQTTnet.Core.Exceptions; using System; +using System.IO; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; @@ -11,8 +12,10 @@ namespace MQTTnet.Implementations public sealed class MqttWebSocketChannel : IMqttCommunicationChannel, IDisposable { private ClientWebSocket _webSocket = new ClientWebSocket(); - private int _bufferSize; - private int _bufferOffset; + + public Stream SendStream => RawStream; + public Stream ReceiveStream => RawStream; + public Stream RawStream { get; private set; } public async Task ConnectAsync(MqttClientOptions options) { @@ -22,6 +25,8 @@ namespace MQTTnet.Implementations { _webSocket = new ClientWebSocket(); await _webSocket.ConnectAsync(new Uri(options.Server), CancellationToken.None); + + RawStream = new WebSocketStream(_webSocket); } catch (WebSocketException exception) { @@ -31,6 +36,7 @@ namespace MQTTnet.Implementations public Task DisconnectAsync() { + RawStream = null; return _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); } @@ -38,64 +44,5 @@ namespace MQTTnet.Implementations { _webSocket?.Dispose(); } - - public async Task> ReadAsync(int length, byte[] buffer) - { - await ReadToBufferAsync(length, buffer).ConfigureAwait(false); - - var result = new ArraySegment(buffer, _bufferOffset, length); - _bufferSize -= length; - _bufferOffset += length; - - return result; - } - - private async Task ReadToBufferAsync(int length, byte[] buffer) - { - if (_bufferSize > 0) - { - return; - } - - var offset = 0; - while (_webSocket.State == WebSocketState.Open && _bufferSize < length) - { - WebSocketReceiveResult response; - do - { - response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, offset, buffer.Length - offset), CancellationToken.None).ConfigureAwait(false); - offset += response.Count; - } while (!response.EndOfMessage); - - _bufferSize = response.Count; - - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); - } - } - } - - public async Task WriteAsync(byte[] buffer) - { - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } - - try - { - await _webSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Binary, true, CancellationToken.None); - } - catch (WebSocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - - public int Peek() - { - return _bufferSize; - } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs new file mode 100644 index 0000000..d912148 --- /dev/null +++ b/Frameworks/MQTTnet.NetStandard/Implementations/WebSocketStream.cs @@ -0,0 +1,80 @@ +using System; +using System.IO; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Implementations +{ + public class WebSocketStream : Stream + { + private readonly ClientWebSocket _webSocket; + + public WebSocketStream(ClientWebSocket webSocket) + { + _webSocket = webSocket; + } + + public override void Flush() + { + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var currentOffset = offset; + var targetOffset = offset + count; + while (_webSocket.State == WebSocketState.Open && currentOffset < targetOffset) + { + var response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, currentOffset, count), cancellationToken).ConfigureAwait(false); + currentOffset += response.Count; + + if (response.MessageType == WebSocketMessageType.Close) + { + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); + } + } + + return currentOffset - offset; + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + } +} diff --git a/Frameworks/MQTTnet.NetStandard/MqttClientFactory.cs b/Frameworks/MQTTnet.NetStandard/MqttClientFactory.cs index 5857075..25ffd8c 100644 --- a/Frameworks/MQTTnet.NetStandard/MqttClientFactory.cs +++ b/Frameworks/MQTTnet.NetStandard/MqttClientFactory.cs @@ -11,9 +11,7 @@ namespace MQTTnet { public IMqttClient CreateMqttClient(MqttClientOptions options) { - if (options == null) { - throw new ArgumentNullException(nameof(options)); - } + if (options == null) throw new ArgumentNullException(nameof(options)); return new MqttClient(options, new MqttChannelCommunicationAdapter(GetMqttCommunicationChannel(options), new MqttPacketSerializer())); } diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs index 8831198..2258ab3 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttServerAdapter.cs @@ -10,7 +10,7 @@ namespace MQTTnet.Implementations public class MqttServerAdapter : IMqttServerAdapter, IDisposable { private StreamSocketListener _defaultEndpointSocket; - + private bool _isRunning; public event EventHandler ClientConnected; @@ -20,13 +20,14 @@ namespace MQTTnet.Implementations if (options == null) throw new ArgumentNullException(nameof(options)); if (_isRunning) throw new InvalidOperationException("Server is already started."); + _isRunning = true; if (options.DefaultEndpointOptions.IsEnabled) { _defaultEndpointSocket = new StreamSocketListener(); _defaultEndpointSocket.BindServiceNameAsync(options.GetDefaultEndpointPort().ToString(), SocketProtectionLevel.PlainSocket).GetAwaiter().GetResult(); - _defaultEndpointSocket.ConnectionReceived += AcceptDefaultEndpointConnectionsAsync; + _defaultEndpointSocket.ConnectionReceived += AcceptDefaultEndpointConnectionsAsync; } if (options.TlsEndpointOptions.IsEnabled) diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs index 5a89ac7..bc224cb 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs @@ -2,6 +2,7 @@ using MQTTnet.Core.Client; using MQTTnet.Core.Exceptions; using System; +using System.IO; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; @@ -11,8 +12,11 @@ namespace MQTTnet.Implementations public sealed class MqttWebSocketChannel : IMqttCommunicationChannel, IDisposable { private ClientWebSocket _webSocket = new ClientWebSocket(); - private int _bufferSize; - private int _bufferOffset; + + public Stream RawStream { get; private set; } + + public Stream SendStream => RawStream; + public Stream ReceiveStream => RawStream; public async Task ConnectAsync(MqttClientOptions options) { @@ -22,6 +26,8 @@ namespace MQTTnet.Implementations { _webSocket = new ClientWebSocket(); await _webSocket.ConnectAsync(new Uri(options.Server), CancellationToken.None); + + RawStream = new WebSocketStream(_webSocket); } catch (WebSocketException exception) { @@ -31,6 +37,7 @@ namespace MQTTnet.Implementations public Task DisconnectAsync() { + RawStream = null; return _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); } @@ -38,64 +45,5 @@ namespace MQTTnet.Implementations { _webSocket?.Dispose(); } - - public async Task> ReadAsync(int length, byte[] buffer) - { - await ReadToBufferAsync(length, buffer).ConfigureAwait(false); - - var result = new ArraySegment(buffer, _bufferOffset, length); - _bufferSize -= length; - _bufferOffset += length; - - return result; - } - - private async Task ReadToBufferAsync(int length, byte[] buffer) - { - if (_bufferSize > 0) - { - return; - } - - var offset = 0; - while (_webSocket.State == WebSocketState.Open && _bufferSize < length) - { - WebSocketReceiveResult response; - do - { - response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, offset, buffer.Length - offset), CancellationToken.None).ConfigureAwait(false); - offset += response.Count; - } while (!response.EndOfMessage); - - _bufferSize = response.Count; - - if (response.MessageType == WebSocketMessageType.Close) - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); - } - } - } - - public async Task WriteAsync(byte[] buffer) - { - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } - - try - { - await _webSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Binary, true, CancellationToken.None); - } - catch (WebSocketException exception) - { - throw new MqttCommunicationException(exception); - } - } - - public int Peek() - { - return _bufferSize; - } } } \ No newline at end of file diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/WebSocketStream.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/WebSocketStream.cs new file mode 100644 index 0000000..d912148 --- /dev/null +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/WebSocketStream.cs @@ -0,0 +1,80 @@ +using System; +using System.IO; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Implementations +{ + public class WebSocketStream : Stream + { + private readonly ClientWebSocket _webSocket; + + public WebSocketStream(ClientWebSocket webSocket) + { + _webSocket = webSocket; + } + + public override void Flush() + { + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var currentOffset = offset; + var targetOffset = offset + count; + while (_webSocket.State == WebSocketState.Open && currentOffset < targetOffset) + { + var response = await _webSocket.ReceiveAsync(new ArraySegment(buffer, currentOffset, count), cancellationToken).ConfigureAwait(false); + currentOffset += response.Count; + + if (response.MessageType == WebSocketMessageType.Close) + { + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); + } + } + + return currentOffset - offset; + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + } +} diff --git a/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj b/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj index 3dd8a2d..8767e24 100644 --- a/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj +++ b/Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj @@ -116,6 +116,7 @@ + diff --git a/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs b/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs index 77ab898..d16579e 100644 --- a/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Threading.Tasks; using MQTTnet.Core.Client; using MQTTnet.Core.Packets; @@ -12,7 +13,7 @@ namespace MQTTnet.Core.Adapter Task DisconnectAsync(); - Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout); + Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets); Task ReceivePacketAsync(TimeSpan timeout); diff --git a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs index d0f1c6e..241bd54 100644 --- a/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs +++ b/MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs @@ -1,9 +1,12 @@ using System; +using System.Collections.Generic; +using System.IO; using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; using MQTTnet.Core.Serializer; @@ -12,6 +15,9 @@ namespace MQTTnet.Core.Adapter public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter { private readonly IMqttCommunicationChannel _channel; + private readonly byte[] _readBuffer = new byte[BufferConstants.Size]; + + private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write public MqttChannelCommunicationAdapter(IMqttCommunicationChannel channel, IMqttPacketSerializer serializer) { @@ -23,7 +29,7 @@ namespace MQTTnet.Core.Adapter public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) { - return ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout); + return _channel.ConnectAsync(options).TimeoutAfter(timeout); } public Task DisconnectAsync() @@ -31,25 +37,38 @@ namespace MQTTnet.Core.Adapter return _channel.DisconnectAsync(); } - public Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout) + public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets) { - MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); + lock (_channel) + { + foreach (var packet in packets) + { + MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); + + var writeBuffer = PacketSerializer.Serialize(packet); + + _sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); + } + } - return ExecuteWithTimeoutAsync(PacketSerializer.SerializeAsync(packet, _channel), timeout); + await _sendTask; // configure await false geneates stackoverflow + await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); } public async Task ReceivePacketAsync(TimeSpan timeout) { - MqttBasePacket packet; + Tuple tuple; if (timeout > TimeSpan.Zero) { - packet = await ExecuteWithTimeoutAsync(PacketSerializer.DeserializeAsync(_channel), timeout).ConfigureAwait(false); + tuple = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); } else { - packet = await PacketSerializer.DeserializeAsync(_channel).ConfigureAwait(false); + tuple = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); } + var packet = PacketSerializer.Deserialize(tuple.Item1, tuple.Item2); + if (packet == null) { throw new MqttProtocolViolationException("Received malformed packet."); @@ -59,34 +78,27 @@ namespace MQTTnet.Core.Adapter return packet; } - private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) + private async Task> ReceiveAsync(Stream stream) { - var timeoutTask = Task.Delay(timeout); - if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) - { - throw new MqttCommunicationTimedOutException(); - } + var header = MqttPacketReader.ReadHeaderFromSource(stream); - if (task.IsFaulted) + MemoryStream body; + if (header.BodyLength > 0) { - throw new MqttCommunicationException(task.Exception); + var totalRead = 0; + do + { + var read = await stream.ReadAsync(_readBuffer, totalRead, header.BodyLength - totalRead).ConfigureAwait(false); + totalRead += read; + } while (totalRead < header.BodyLength); + body = new MemoryStream(_readBuffer, 0, header.BodyLength); } - - return task.Result; - } - - private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout) - { - var timeoutTask = Task.Delay(timeout); - if (await Task.WhenAny(timeoutTask, task).ConfigureAwait(false) == timeoutTask) + else { - throw new MqttCommunicationTimedOutException(); + body = new MemoryStream(); } - if (task.IsFaulted) - { - throw new MqttCommunicationException(task.Exception); - } + return Tuple.Create(header, body); } } } \ No newline at end of file diff --git a/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs b/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs new file mode 100644 index 0000000..d25d172 --- /dev/null +++ b/MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs @@ -0,0 +1,14 @@ +using System; +using System.Threading.Tasks; +using MQTTnet.Core.Packets; + +namespace MQTTnet.Core.Adapter +{ + public static class MqttCommunicationAdapterExtensions + { + public static Task SendPacketsAsync(this IMqttCommunicationAdapter adapter, TimeSpan timeout, params MqttBasePacket[] packets) + { + return adapter.SendPacketsAsync(timeout, packets); + } + } +} \ No newline at end of file diff --git a/MQTTnet.Core/Adapter/MqttConnectingFailedException.cs b/MQTTnet.Core/Adapter/MqttConnectingFailedException.cs index aaf94f4..045d0d1 100644 --- a/MQTTnet.Core/Adapter/MqttConnectingFailedException.cs +++ b/MQTTnet.Core/Adapter/MqttConnectingFailedException.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Core.Adapter { public class MqttConnectingFailedException : MqttCommunicationException { - public MqttConnectingFailedException(MqttConnectReturnCode returnCode) + public MqttConnectingFailedException(MqttConnectReturnCode returnCode) : base($"Connecting with MQTT server failed ({returnCode}).") { ReturnCode = returnCode; diff --git a/MQTTnet.Core/Channel/BufferedCommunicationChannel.cs b/MQTTnet.Core/Channel/BufferedCommunicationChannel.cs deleted file mode 100644 index 757f9c2..0000000 --- a/MQTTnet.Core/Channel/BufferedCommunicationChannel.cs +++ /dev/null @@ -1,77 +0,0 @@ -using System.Threading.Tasks; -using MQTTnet.Core.Client; -using System; - -namespace MQTTnet.Core.Channel -{ - public class BufferedCommunicationChannel : IMqttCommunicationChannel - { - private readonly IMqttCommunicationChannel _inner; - private int _bufferSize; - private int _bufferOffset; - - public BufferedCommunicationChannel(IMqttCommunicationChannel inner) - { - _inner = inner; - } - - public Task ConnectAsync(MqttClientOptions options) - { - return _inner.ConnectAsync(options); - } - - public Task DisconnectAsync() - { - return _inner.DisconnectAsync(); - } - - public int Peek() - { - return _inner.Peek(); - } - - public async Task> ReadAsync(int length, byte[] buffer) - { - //read from buffer - if (_bufferSize > 0) - { - return ReadFomBuffer(length, buffer); - } - - var available = _inner.Peek(); - // if there are less or equal bytes available then requested then just read em - if (available <= length) - { - return await _inner.ReadAsync(length, buffer); - } - - //if more bytes are available than requested do buffer them to reduce calls to network buffers - await WriteToBuffer(available, buffer).ConfigureAwait(false); - return ReadFomBuffer(length, buffer); - } - - private async Task WriteToBuffer(int available, byte[] buffer) - { - await _inner.ReadAsync(available, buffer).ConfigureAwait(false); - _bufferSize = available; - _bufferOffset = 0; - } - - private ArraySegment ReadFomBuffer(int length, byte[] buffer) - { - var result = new ArraySegment(buffer, _bufferOffset, length); - _bufferSize -= length; - _bufferOffset += length; - - if (_bufferSize < 0) - { - } - return result; - } - - public Task WriteAsync(byte[] buffer) - { - return _inner.WriteAsync(buffer); - } - } -} diff --git a/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs b/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs index 0f6ea4b..80c1308 100644 --- a/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs +++ b/MQTTnet.Core/Channel/IMqttCommunicationChannel.cs @@ -1,6 +1,6 @@ using System.Threading.Tasks; using MQTTnet.Core.Client; -using System; +using System.IO; namespace MQTTnet.Core.Channel { @@ -9,14 +9,11 @@ namespace MQTTnet.Core.Channel Task ConnectAsync(MqttClientOptions options); Task DisconnectAsync(); + + Stream SendStream { get; } - Task WriteAsync(byte[] buffer); + Stream ReceiveStream { get; } - /// - /// get the currently available number of bytes without reading them - /// - int Peek(); - - Task> ReadAsync(int length, byte[] buffer); + Stream RawStream { get; } } } diff --git a/MQTTnet.Core/Client/IMqttClient.cs b/MQTTnet.Core/Client/IMqttClient.cs index 0170adf..1b22edf 100644 --- a/MQTTnet.Core/Client/IMqttClient.cs +++ b/MQTTnet.Core/Client/IMqttClient.cs @@ -15,7 +15,7 @@ namespace MQTTnet.Core.Client Task ConnectAsync(MqttApplicationMessage willApplicationMessage = null); Task DisconnectAsync(); - Task PublishAsync(MqttApplicationMessage applicationMessage); + Task PublishAsync(IEnumerable applicationMessages); Task> SubscribeAsync(IList topicFilters); Task> SubscribeAsync(params TopicFilter[] topicFilters); Task Unsubscribe(IList topicFilters); diff --git a/MQTTnet.Core/Client/MqttClient.cs b/MQTTnet.Core/Client/MqttClient.cs index deb0242..aabc5a6 100644 --- a/MQTTnet.Core/Client/MqttClient.cs +++ b/MQTTnet.Core/Client/MqttClient.cs @@ -161,32 +161,43 @@ namespace MQTTnet.Core.Client return SendAndReceiveAsync(unsubscribePacket); } - public Task PublishAsync(MqttApplicationMessage applicationMessage) + public async Task PublishAsync(IEnumerable applicationMessages) { - if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); ThrowIfNotConnected(); - var publishPacket = applicationMessage.ToPublishPacket(); + var publishPackets = applicationMessages.Select(m => m.ToPublishPacket()); - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) + foreach (var qosGroup in publishPackets.GroupBy(p => p.QualityOfServiceLevel)) { - // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] - return SendAsync(publishPacket); - } - - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) - { - publishPacket.PacketIdentifier = GetNewPacketIdentifier(); - return SendAndReceiveAsync(publishPacket); - } - - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) - { - publishPacket.PacketIdentifier = GetNewPacketIdentifier(); - return PublishExactlyOncePacketAsync(publishPacket); + var qosPackets = qosGroup.ToArray(); + switch (qosGroup.Key) + { + case MqttQualityOfServiceLevel.AtMostOnce: + // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] + await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, qosPackets); + break; + case MqttQualityOfServiceLevel.AtLeastOnce: + { + foreach (var publishPacket in qosPackets) + { + publishPacket.PacketIdentifier = GetNewPacketIdentifier(); + await SendAndReceiveAsync(publishPacket); + } + break; + } + case MqttQualityOfServiceLevel.ExactlyOnce: + { + foreach (var publishPacket in qosPackets) + { + publishPacket.PacketIdentifier = GetNewPacketIdentifier(); + await PublishExactlyOncePacketAsync(publishPacket); + } + break; + } + default: + throw new InvalidOperationException(); + } } - - throw new InvalidOperationException(); } private async Task PublishExactlyOncePacketAsync(MqttBasePacket publishPacket) @@ -277,12 +288,14 @@ namespace MQTTnet.Core.Client if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) { FireApplicationMessageReceivedEvent(publishPacket); + return; } if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) { FireApplicationMessageReceivedEvent(publishPacket); await SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + return; } if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) @@ -295,6 +308,7 @@ namespace MQTTnet.Core.Client FireApplicationMessageReceivedEvent(publishPacket); await SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); + return; } throw new MqttCommunicationException("Received a not supported QoS level."); @@ -312,13 +326,12 @@ namespace MQTTnet.Core.Client private Task SendAsync(MqttBasePacket packet) { - return _adapter.SendPacketAsync(packet, _options.DefaultCommunicationTimeout); + return _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, packet); } private async Task SendAndReceiveAsync(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket { - await _adapter.SendPacketAsync(requestPacket, _options.DefaultCommunicationTimeout).ConfigureAwait(false); - + await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, requestPacket).ConfigureAwait(false); return (TResponsePacket)await _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.DefaultCommunicationTimeout).ConfigureAwait(false); } diff --git a/MQTTnet.Core/Client/MqttClientExtensions.cs b/MQTTnet.Core/Client/MqttClientExtensions.cs new file mode 100644 index 0000000..5e9875a --- /dev/null +++ b/MQTTnet.Core/Client/MqttClientExtensions.cs @@ -0,0 +1,12 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Core.Client +{ + public static class MqttClientExtensions + { + public static Task PublishAsync(this IMqttClient client, params MqttApplicationMessage[] applicationMessages) + { + return client.PublishAsync(applicationMessages); + } + } +} \ No newline at end of file diff --git a/MQTTnet.Core/Client/MqttPacketDispatcher.cs b/MQTTnet.Core/Client/MqttPacketDispatcher.cs index f057b97..1551d3f 100644 --- a/MQTTnet.Core/Client/MqttPacketDispatcher.cs +++ b/MQTTnet.Core/Client/MqttPacketDispatcher.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Threading.Tasks; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; using System.Collections.Concurrent; @@ -13,7 +14,7 @@ namespace MQTTnet.Core.Client private readonly object _syncRoot = new object(); private readonly HashSet _receivedPackets = new HashSet(); private readonly ConcurrentDictionary> _packetByResponseType = new ConcurrentDictionary>(); - private readonly ConcurrentDictionary> _packetByIdentifier = new ConcurrentDictionary>(); + private readonly ConcurrentDictionary> _packetByIdentifier = new ConcurrentDictionary>(); public async Task WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) { @@ -22,16 +23,19 @@ namespace MQTTnet.Core.Client var packetAwaiter = AddPacketAwaiter(request, responseType); DispatchPendingPackets(); - var hasTimeout = await Task.WhenAny(Task.Delay(timeout), packetAwaiter.Task).ConfigureAwait(false) != packetAwaiter.Task; - RemovePacketAwaiter(request, responseType); - - if (hasTimeout) + try + { + return await packetAwaiter.Task.TimeoutAfter(timeout); + } + catch (MqttCommunicationTimedOutException) { MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); - throw new MqttCommunicationTimedOutException(); + throw; + } + finally + { + RemovePacketAwaiter(request, responseType); } - - return packetAwaiter.Task.Result; } public void Dispatch(MqttBasePacket packet) @@ -48,9 +52,9 @@ namespace MQTTnet.Core.Client packetDispatched = true; } } - else if (_packetByResponseType.TryRemove(packet.GetType(), out var tcs) ) + else if (_packetByResponseType.TryRemove(packet.GetType(), out var tcs)) { - tcs.TrySetResult( packet); + tcs.TrySetResult(packet); packetDispatched = true; } @@ -96,11 +100,11 @@ namespace MQTTnet.Core.Client { if (request is IMqttPacketWithIdentifier withIdent) { - _packetByIdentifier.TryRemove(withIdent.PacketIdentifier, out var tcs); + _packetByIdentifier.TryRemove(withIdent.PacketIdentifier, out var _); } else { - _packetByResponseType.TryRemove(responseType, out var tcs); + _packetByResponseType.TryRemove(responseType, out var _); } } diff --git a/MQTTnet.Core/Diagnostics/MqttTrace.cs b/MQTTnet.Core/Diagnostics/MqttTrace.cs index 1036028..5cce5ec 100644 --- a/MQTTnet.Core/Diagnostics/MqttTrace.cs +++ b/MQTTnet.Core/Diagnostics/MqttTrace.cs @@ -1,5 +1,4 @@ using System; -using System.Linq; namespace MQTTnet.Core.Diagnostics { diff --git a/MQTTnet.Core/Diagnostics/MqttTraceLevel.cs b/MQTTnet.Core/Diagnostics/MqttTraceLevel.cs index 0e7463f..86a0dc1 100644 --- a/MQTTnet.Core/Diagnostics/MqttTraceLevel.cs +++ b/MQTTnet.Core/Diagnostics/MqttTraceLevel.cs @@ -1,6 +1,6 @@ namespace MQTTnet.Core.Diagnostics { - public enum MqttTraceLevel + public enum MqttTraceLevel { Verbose, Information, diff --git a/MQTTnet.Core/Exceptions/MqttProtocolViolationException.cs b/MQTTnet.Core/Exceptions/MqttProtocolViolationException.cs index a4724d8..cef43f5 100644 --- a/MQTTnet.Core/Exceptions/MqttProtocolViolationException.cs +++ b/MQTTnet.Core/Exceptions/MqttProtocolViolationException.cs @@ -4,7 +4,7 @@ namespace MQTTnet.Core.Exceptions { public sealed class MqttProtocolViolationException : Exception { - public MqttProtocolViolationException(string message) + public MqttProtocolViolationException(string message) : base(message) { } diff --git a/MQTTnet.Core/Internal/AsyncAutoResetEvent.cs b/MQTTnet.Core/Internal/AsyncAutoResetEvent.cs deleted file mode 100644 index 795748d..0000000 --- a/MQTTnet.Core/Internal/AsyncAutoResetEvent.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace MQTTnet.Core.Internal -{ - public sealed class AsyncGate - { - private readonly Queue> _waitingTasks = new Queue>(); - - public Task WaitOneAsync() - { - var tcs = new TaskCompletionSource(); - lock (_waitingTasks) - { - _waitingTasks.Enqueue(tcs); - } - - return tcs.Task; - } - - public void Set() - { - lock (_waitingTasks) - { - if (_waitingTasks.Count > 0) - { - _waitingTasks.Dequeue().SetResult(true); - } - } - } - } -} diff --git a/MQTTnet.Core/Internal/TaskExtensions.cs b/MQTTnet.Core/Internal/TaskExtensions.cs new file mode 100644 index 0000000..485bd57 --- /dev/null +++ b/MQTTnet.Core/Internal/TaskExtensions.cs @@ -0,0 +1,55 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Core.Exceptions; + +namespace MQTTnet.Core.Internal +{ + public static class TaskExtensions + { + public static Task TimeoutAfter(this Task task, TimeSpan timeout) + { + return TimeoutAfter(task.ContinueWith(t => 0), timeout); + } + + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + using (var cancellationTokenSource = new CancellationTokenSource()) + { + var tcs = new TaskCompletionSource(); + + cancellationTokenSource.Token.Register(() => + { + tcs.TrySetCanceled(); + }); + + try + { + cancellationTokenSource.CancelAfter(timeout); + task.ContinueWith(t => + { + if (t.IsFaulted) + { + tcs.TrySetException(t.Exception); + } + + if (t.IsCompleted) + { + tcs.TrySetResult(t.Result); + } + }, cancellationTokenSource.Token); + + return await tcs.Task; + } + catch (TaskCanceledException) + { + throw new MqttCommunicationTimedOutException(); + } + catch (Exception e) + { + throw new MqttCommunicationException(e); + } + } + } + } +} diff --git a/MQTTnet.Core/MQTTnet.Core.csproj b/MQTTnet.Core/MQTTnet.Core.csproj index 480b4c7..797bb85 100644 --- a/MQTTnet.Core/MQTTnet.Core.csproj +++ b/MQTTnet.Core/MQTTnet.Core.csproj @@ -5,6 +5,7 @@ MQTTnet.Core MQTTnet.Core False + Full diff --git a/MQTTnet.Core/Packets/MqttPublishPacket.cs b/MQTTnet.Core/Packets/MqttPublishPacket.cs index 8476ccb..447aaed 100644 --- a/MQTTnet.Core/Packets/MqttPublishPacket.cs +++ b/MQTTnet.Core/Packets/MqttPublishPacket.cs @@ -17,7 +17,7 @@ namespace MQTTnet.Core.Packets public override string ToString() { - return nameof(MqttPublishPacket) + + return nameof(MqttPublishPacket) + ": [Topic=" + Topic + "]" + " [Payload=" + Convert.ToBase64String(Payload) + "]" + " [QoSLevel=" + QualityOfServiceLevel + "]" + diff --git a/MQTTnet.Core/Packets/MqttSubscribePacket.cs b/MQTTnet.Core/Packets/MqttSubscribePacket.cs index 6f4877b..3fa379b 100644 --- a/MQTTnet.Core/Packets/MqttSubscribePacket.cs +++ b/MQTTnet.Core/Packets/MqttSubscribePacket.cs @@ -6,7 +6,7 @@ namespace MQTTnet.Core.Packets public sealed class MqttSubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier { public ushort PacketIdentifier { get; set; } - + public IList TopicFilters { get; set; } = new List(); public override string ToString() diff --git a/MQTTnet.Core/Packets/MqttUnsubscribe.cs b/MQTTnet.Core/Packets/MqttUnsubscribe.cs index b6cfab6..269bfcb 100644 --- a/MQTTnet.Core/Packets/MqttUnsubscribe.cs +++ b/MQTTnet.Core/Packets/MqttUnsubscribe.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Core.Packets public sealed class MqttUnsubscribePacket : MqttBasePacket, IMqttPacketWithIdentifier { public ushort PacketIdentifier { get; set; } - + public IList TopicFilters { get; set; } = new List(); } } diff --git a/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs b/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs index 801b7ea..df5b045 100644 --- a/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/IMqttPacketSerializer.cs @@ -1,5 +1,4 @@ -using System.Threading.Tasks; -using MQTTnet.Core.Channel; +using System.IO; using MQTTnet.Core.Packets; namespace MQTTnet.Core.Serializer @@ -8,8 +7,8 @@ namespace MQTTnet.Core.Serializer { MqttProtocolVersion ProtocolVersion { get; set; } - Task SerializeAsync(MqttBasePacket mqttPacket, IMqttCommunicationChannel destination); + byte[] Serialize(MqttBasePacket mqttPacket); - Task DeserializeAsync(IMqttCommunicationChannel source); + MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream stream); } } \ No newline at end of file diff --git a/MQTTnet.Core/Serializer/MqttPacketReader.cs b/MQTTnet.Core/Serializer/MqttPacketReader.cs index bdef7fa..def590b 100644 --- a/MQTTnet.Core/Serializer/MqttPacketReader.cs +++ b/MQTTnet.Core/Serializer/MqttPacketReader.cs @@ -1,10 +1,8 @@ using System; using System.IO; using System.Text; -using System.Threading.Tasks; using MQTTnet.Core.Exceptions; using MQTTnet.Core.Protocol; -using MQTTnet.Core.Channel; using MQTTnet.Core.Packets; namespace MQTTnet.Core.Serializer @@ -13,14 +11,14 @@ namespace MQTTnet.Core.Serializer { private readonly MqttPacketHeader _header; - public MqttPacketReader(MqttPacketHeader header, Stream body) - : base(body) + public MqttPacketReader(Stream stream, MqttPacketHeader header) + : base(stream, Encoding.UTF8, true) { _header = header; } - + public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; - + public override ushort ReadUInt16() { var buffer = ReadBytes(2); @@ -49,14 +47,11 @@ namespace MQTTnet.Core.Serializer return ReadBytes(_header.BodyLength - (int)BaseStream.Position); } - public static async Task ReadHeaderFromSourceAsync(IMqttCommunicationChannel source, byte[] buffer) + public static MqttPacketHeader ReadHeaderFromSource(Stream stream) { - var fixedHeader = await ReadStreamByteAsync(source, buffer).ConfigureAwait(false); - var byteReader = new ByteReader(fixedHeader); - byteReader.Read(4); - - var controlPacketType = (MqttControlPacketType)byteReader.Read(4); - var bodyLength = await ReadBodyLengthFromSourceAsync(source, buffer).ConfigureAwait(false); + var fixedHeader = (byte)stream.ReadByte(); + var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); + var bodyLength = ReadBodyLengthFromSource(stream); return new MqttPacketHeader { @@ -66,25 +61,7 @@ namespace MQTTnet.Core.Serializer }; } - private static async Task ReadStreamByteAsync(IMqttCommunicationChannel source, byte[] readBuffer) - { - var result = await ReadFromSourceAsync(source, 1, readBuffer).ConfigureAwait(false); - return result.Array[result.Offset]; - } - - public static async Task> ReadFromSourceAsync(IMqttCommunicationChannel source, int length, byte[] buffer) - { - try - { - return await source.ReadAsync(length, buffer); - } - catch (Exception exception) - { - throw new MqttCommunicationException(exception); - } - } - - private static async Task ReadBodyLengthFromSourceAsync(IMqttCommunicationChannel source, byte[] buffer) + private static int ReadBodyLengthFromSource(Stream stream) { // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. var multiplier = 1; @@ -92,7 +69,7 @@ namespace MQTTnet.Core.Serializer byte encodedByte; do { - encodedByte = await ReadStreamByteAsync(source, buffer).ConfigureAwait(false); + encodedByte = (byte)stream.ReadByte(); value += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128 * 128 * 128) diff --git a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs index 4de39af..5dea752 100644 --- a/MQTTnet.Core/Serializer/MqttPacketSerializer.cs +++ b/MQTTnet.Core/Serializer/MqttPacketSerializer.cs @@ -3,8 +3,6 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; -using System.Threading.Tasks; -using MQTTnet.Core.Channel; using MQTTnet.Core.Exceptions; using MQTTnet.Core.Packets; using MQTTnet.Core.Protocol; @@ -17,12 +15,10 @@ namespace MQTTnet.Core.Serializer private static byte[] ProtocolVersionV310Name { get; } = Encoding.UTF8.GetBytes("MQIs"); public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; - private byte[] _readBuffer = new byte[BufferConstants.Size]; // TODO: What happens if the message is bigger? - public async Task SerializeAsync(MqttBasePacket packet, IMqttCommunicationChannel destination) + public byte[] Serialize(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); - if (destination == null) throw new ArgumentNullException(nameof(destination)); using (var stream = new MemoryStream()) using (var writer = new MqttPacketWriter(stream)) @@ -30,9 +26,12 @@ namespace MQTTnet.Core.Serializer var header = new List { SerializePacket(packet, writer) }; var body = stream.ToArray(); MqttPacketWriter.BuildLengthHeader(body.Length, header); - - await destination.WriteAsync(header.ToArray()).ConfigureAwait(false); - await destination.WriteAsync(body).ConfigureAwait(false); + var headerArray = header.ToArray(); + var writeBuffer = new byte[header.Count + body.Length]; + Buffer.BlockCopy(headerArray, 0, writeBuffer, 0, headerArray.Length); + Buffer.BlockCopy(body, 0, writeBuffer, headerArray.Length, body.Length); + + return writeBuffer; } } @@ -111,121 +110,114 @@ namespace MQTTnet.Core.Serializer throw new MqttProtocolViolationException("Packet type invalid."); } - public async Task DeserializeAsync(IMqttCommunicationChannel source) + public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) { - if (source == null) throw new ArgumentNullException(nameof(source)); - - var header = await MqttPacketReader.ReadHeaderFromSourceAsync(source, _readBuffer).ConfigureAwait(false); - var body = await GetBody(source, header).ConfigureAwait(false); + if (header == null) throw new ArgumentNullException(nameof(header)); + if (body == null) throw new ArgumentNullException(nameof(body)); - using (var mqttPacketReader = new MqttPacketReader(header, body)) + using (var reader = new MqttPacketReader(body, header)) { - switch (header.ControlPacketType) - { - case MqttControlPacketType.Connect: - { - return DeserializeConnect(mqttPacketReader); - } - - case MqttControlPacketType.ConnAck: - { - return DeserializeConnAck(mqttPacketReader); - } - - case MqttControlPacketType.Disconnect: - { - return new MqttDisconnectPacket(); - } - - case MqttControlPacketType.Publish: - { - return DeserializePublish(mqttPacketReader, header); - } - - case MqttControlPacketType.PubAck: - { - return new MqttPubAckPacket - { - PacketIdentifier = mqttPacketReader.ReadUInt16() - }; - } - - case MqttControlPacketType.PubRec: - { - return new MqttPubRecPacket - { - PacketIdentifier = mqttPacketReader.ReadUInt16() - }; - } - - case MqttControlPacketType.PubRel: - { - return new MqttPubRelPacket - { - PacketIdentifier = mqttPacketReader.ReadUInt16() - }; - } - - case MqttControlPacketType.PubComp: - { - return new MqttPubCompPacket - { - PacketIdentifier = mqttPacketReader.ReadUInt16() - }; - } - - case MqttControlPacketType.PingReq: - { - return new MqttPingReqPacket(); - } + return Deserialize(header, reader); + } + } - case MqttControlPacketType.PingResp: + private static MqttBasePacket Deserialize(MqttPacketHeader header, MqttPacketReader reader) + { + switch (header.ControlPacketType) + { + case MqttControlPacketType.Connect: + { + return DeserializeConnect(reader); + } + + case MqttControlPacketType.ConnAck: + { + return DeserializeConnAck(reader); + } + + case MqttControlPacketType.Disconnect: + { + return new MqttDisconnectPacket(); + } + + case MqttControlPacketType.Publish: + { + return DeserializePublish(reader, header); + } + + case MqttControlPacketType.PubAck: + { + return new MqttPubAckPacket { - return new MqttPingRespPacket(); - } + PacketIdentifier = reader.ReadUInt16() + }; + } - case MqttControlPacketType.Subscribe: + case MqttControlPacketType.PubRec: + { + return new MqttPubRecPacket { - return DeserializeSubscribe(mqttPacketReader); - } + PacketIdentifier = reader.ReadUInt16() + }; + } - case MqttControlPacketType.SubAck: + case MqttControlPacketType.PubRel: + { + return new MqttPubRelPacket { - return DeserializeSubAck(mqttPacketReader); - } + PacketIdentifier = reader.ReadUInt16() + }; + } - case MqttControlPacketType.Unsubscibe: + case MqttControlPacketType.PubComp: + { + return new MqttPubCompPacket { - return DeserializeUnsubscribe(mqttPacketReader); - } - - case MqttControlPacketType.UnsubAck: + PacketIdentifier = reader.ReadUInt16() + }; + } + + case MqttControlPacketType.PingReq: + { + return new MqttPingReqPacket(); + } + + case MqttControlPacketType.PingResp: + { + return new MqttPingRespPacket(); + } + + case MqttControlPacketType.Subscribe: + { + return DeserializeSubscribe(reader); + } + + case MqttControlPacketType.SubAck: + { + return DeserializeSubAck(reader); + } + + case MqttControlPacketType.Unsubscibe: + { + return DeserializeUnsubscribe(reader); + } + + case MqttControlPacketType.UnsubAck: + { + return new MqttUnsubAckPacket { - return new MqttUnsubAckPacket - { - PacketIdentifier = mqttPacketReader.ReadUInt16() - }; - } + PacketIdentifier = reader.ReadUInt16() + }; + } - default: - { - throw new MqttProtocolViolationException($"Packet type ({(int)header.ControlPacketType}) not supported."); - } - } + default: + { + throw new MqttProtocolViolationException( + $"Packet type ({(int)header.ControlPacketType}) not supported."); + } } } - private async Task GetBody(IMqttCommunicationChannel source, MqttPacketHeader header) - { - if (header.BodyLength > 0) - { - var segment = await MqttPacketReader.ReadFromSourceAsync(source, header.BodyLength, _readBuffer).ConfigureAwait(false); - return new MemoryStream(segment.Array, segment.Offset, segment.Count); - } - - return new MemoryStream(); - } - private static MqttBasePacket DeserializeUnsubscribe(MqttPacketReader reader) { var packet = new MqttUnsubscribePacket @@ -514,12 +506,21 @@ namespace MQTTnet.Core.Serializer writer.Write(packet.Payload); } - var fixedHeader = new ByteWriter(); - fixedHeader.Write(packet.Retain); - fixedHeader.Write((byte)packet.QualityOfServiceLevel, 2); - fixedHeader.Write(packet.Dup); + byte fixedHeader = 0; + + if (packet.Retain) + { + fixedHeader |= 0x01; + } + + fixedHeader |= (byte)((byte)packet.QualityOfServiceLevel << 1); + + if (packet.Dup) + { + fixedHeader |= 0x08; + } - return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader.Value); + return MqttPacketWriter.BuildFixedHeader(MqttControlPacketType.Publish, fixedHeader); } private static byte Serialize(MqttPubAckPacket packet, MqttPacketWriter writer) diff --git a/MQTTnet.Core/Serializer/MqttPacketWriter.cs b/MQTTnet.Core/Serializer/MqttPacketWriter.cs index ae9eef1..d8c997d 100644 --- a/MQTTnet.Core/Serializer/MqttPacketWriter.cs +++ b/MQTTnet.Core/Serializer/MqttPacketWriter.cs @@ -8,10 +8,9 @@ namespace MQTTnet.Core.Serializer { public sealed class MqttPacketWriter : BinaryWriter { - public MqttPacketWriter( Stream stream ) + public MqttPacketWriter(Stream stream) : base(stream) { - } public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) @@ -20,7 +19,7 @@ namespace MQTTnet.Core.Serializer fixedHeader |= flags; return (byte)fixedHeader; } - + public override void Write(ushort value) { var buffer = BitConverter.GetBytes(value); @@ -43,7 +42,7 @@ namespace MQTTnet.Core.Serializer Write(value.Value); } - + public void WriteWithLengthPrefix(string value) { WriteWithLengthPrefix(Encoding.UTF8.GetBytes(value ?? string.Empty)); diff --git a/MQTTnet.Core/Server/MqttClientMessageQueue.cs b/MQTTnet.Core/Server/MqttClientMessageQueue.cs index 22fbdbb..3cbe882 100644 --- a/MQTTnet.Core/Server/MqttClientMessageQueue.cs +++ b/MQTTnet.Core/Server/MqttClientMessageQueue.cs @@ -1,20 +1,17 @@ using System; -using System.Collections.Generic; -using System.Linq; +using System.Collections.Concurrent; using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Adapter; using MQTTnet.Core.Diagnostics; using MQTTnet.Core.Exceptions; -using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; namespace MQTTnet.Core.Server { public sealed class MqttClientMessageQueue { - private readonly List _pendingPublishPackets = new List(); - private readonly AsyncGate _gate = new AsyncGate(); + private readonly BlockingCollection _pendingPublishPackets = new BlockingCollection(); private readonly MqttServerOptions _options; private CancellationTokenSource _cancellationTokenSource; @@ -43,26 +40,22 @@ namespace MQTTnet.Core.Server _adapter = null; _cancellationTokenSource?.Cancel(); _cancellationTokenSource = null; + _pendingPublishPackets?.Dispose(); } public void Enqueue(MqttPublishPacket publishPacket) { if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); - lock (_pendingPublishPackets) - { - _pendingPublishPackets.Add(new MqttClientPublishPacketContext(publishPacket)); - _gate.Set(); - } + _pendingPublishPackets.Add(new MqttClientPublishPacketContext(publishPacket)); } private async Task SendPendingPublishPacketsAsync(CancellationToken cancellationToken) { - while (!cancellationToken.IsCancellationRequested) + foreach (var publishPacket in _pendingPublishPackets.GetConsumingEnumerable(cancellationToken)) { try { - await _gate.WaitOneAsync().ConfigureAwait(false); if (cancellationToken.IsCancellationRequested) { return; @@ -73,25 +66,12 @@ namespace MQTTnet.Core.Server continue; } - List pendingPublishPackets; - lock (_pendingPublishPackets) - { - pendingPublishPackets = _pendingPublishPackets.ToList(); - } - - foreach (var publishPacket in pendingPublishPackets) - { - await TrySendPendingPublishPacketAsync(publishPacket).ConfigureAwait(false); - } + await TrySendPendingPublishPacketAsync(publishPacket).ConfigureAwait(false); } catch (Exception e) { MqttTrace.Error(nameof(MqttClientMessageQueue), e, "Error while sending pending publish packets."); } - finally - { - Cleanup(); - } } } @@ -105,30 +85,24 @@ namespace MQTTnet.Core.Server } publishPacketContext.PublishPacket.Dup = publishPacketContext.SendTries > 0; - await _adapter.SendPacketAsync(publishPacketContext.PublishPacket, _options.DefaultCommunicationTimeout).ConfigureAwait(false); + await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, publishPacketContext.PublishPacket).ConfigureAwait(false); publishPacketContext.IsSent = true; } catch (MqttCommunicationException exception) { MqttTrace.Warning(nameof(MqttClientMessageQueue), exception, "Sending publish packet failed."); + _pendingPublishPackets.Add(publishPacketContext); } catch (Exception exception) { MqttTrace.Error(nameof(MqttClientMessageQueue), exception, "Sending publish packet failed."); + _pendingPublishPackets.Add(publishPacketContext); } finally { publishPacketContext.SendTries++; } } - - private void Cleanup() - { - lock (_pendingPublishPackets) - { - _pendingPublishPackets.RemoveAll(p => p.IsSent); - } - } } } diff --git a/MQTTnet.Core/Server/MqttClientSession.cs b/MQTTnet.Core/Server/MqttClientSession.cs index 9b811ec..1411dd3 100644 --- a/MQTTnet.Core/Server/MqttClientSession.cs +++ b/MQTTnet.Core/Server/MqttClientSession.cs @@ -103,12 +103,12 @@ namespace MQTTnet.Core.Server { if (packet is MqttSubscribePacket subscribePacket) { - return Adapter.SendPacketAsync(_subscriptionsManager.Subscribe(subscribePacket), _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _subscriptionsManager.Subscribe(subscribePacket)); } if (packet is MqttUnsubscribePacket unsubscribePacket) { - return Adapter.SendPacketAsync(_subscriptionsManager.Unsubscribe(unsubscribePacket), _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _subscriptionsManager.Unsubscribe(unsubscribePacket)); } if (packet is MqttPublishPacket publishPacket) @@ -123,7 +123,7 @@ namespace MQTTnet.Core.Server if (packet is MqttPubRecPacket pubRecPacket) { - return Adapter.SendPacketAsync(pubRecPacket.CreateResponse(), _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, pubRecPacket.CreateResponse()); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) @@ -134,7 +134,7 @@ namespace MQTTnet.Core.Server if (packet is MqttPingReqPacket) { - return Adapter.SendPacketAsync(new MqttPingRespPacket(), _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket()); } if (packet is MqttDisconnectPacket || packet is MqttConnectPacket) @@ -160,7 +160,7 @@ namespace MQTTnet.Core.Server if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) { _publishPacketReceivedCallback(this, publishPacket); - return Adapter.SendPacketAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }, _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier }); } if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) @@ -173,7 +173,7 @@ namespace MQTTnet.Core.Server _publishPacketReceivedCallback(this, publishPacket); - return Adapter.SendPacketAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }, _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier }); } throw new MqttCommunicationException("Received a not supported QoS level."); @@ -186,7 +186,7 @@ namespace MQTTnet.Core.Server _unacknowledgedPublishPackets.Remove(pubRelPacket.PacketIdentifier); } - return Adapter.SendPacketAsync(new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }, _options.DefaultCommunicationTimeout); + return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier }); } } } diff --git a/MQTTnet.Core/Server/MqttClientSessionsManager.cs b/MQTTnet.Core/Server/MqttClientSessionsManager.cs index b8ff4e6..483398d 100644 --- a/MQTTnet.Core/Server/MqttClientSessionsManager.cs +++ b/MQTTnet.Core/Server/MqttClientSessionsManager.cs @@ -40,21 +40,21 @@ namespace MQTTnet.Core.Server var connectReturnCode = ValidateConnection(connectPacket); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { - await eventArgs.ClientAdapter.SendPacketAsync(new MqttConnAckPacket + await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttConnAckPacket { ConnectReturnCode = connectReturnCode - }, _options.DefaultCommunicationTimeout).ConfigureAwait(false); + }).ConfigureAwait(false); return; } var clientSession = GetOrCreateClientSession(connectPacket); - await eventArgs.ClientAdapter.SendPacketAsync(new MqttConnAckPacket + await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttConnAckPacket { ConnectReturnCode = connectReturnCode, IsSessionPresent = clientSession.IsExistingSession - }, _options.DefaultCommunicationTimeout).ConfigureAwait(false); + }).ConfigureAwait(false); await clientSession.Session.RunAsync(eventArgs.Identifier, connectPacket.WillMessage, eventArgs.ClientAdapter).ConfigureAwait(false); } diff --git a/MQTTnet.Core/Server/MqttServer.cs b/MQTTnet.Core/Server/MqttServer.cs index f3025e6..b6f68ee 100644 --- a/MQTTnet.Core/Server/MqttServer.cs +++ b/MQTTnet.Core/Server/MqttServer.cs @@ -20,7 +20,7 @@ namespace MQTTnet.Core.Server { _options = options ?? throw new ArgumentNullException(nameof(options)); _adapters = adapters ?? throw new ArgumentNullException(nameof(adapters)); - + _clientSessionsManager = new MqttClientSessionsManager(options); _clientSessionsManager.ApplicationMessageReceived += (s, e) => ApplicationMessageReceived?.Invoke(s, e); } @@ -61,7 +61,7 @@ namespace MQTTnet.Core.Server adapter.ClientConnected += OnClientConnected; adapter.Start(_options); } - + MqttTrace.Information(nameof(MqttServer), "Started."); } diff --git a/MQTTnet.Core/Server/MqttServerOptions.cs b/MQTTnet.Core/Server/MqttServerOptions.cs index 5edf500..9de3627 100644 --- a/MQTTnet.Core/Server/MqttServerOptions.cs +++ b/MQTTnet.Core/Server/MqttServerOptions.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Core.Server public MqttServerDefaultEndpointOptions DefaultEndpointOptions { get; } = new MqttServerDefaultEndpointOptions(); public MqttServerTlsEndpointOptions TlsEndpointOptions { get; } = new MqttServerTlsEndpointOptions(); - + public int ConnectionBacklog { get; set; } = 10; public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15); diff --git a/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj b/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj index 7b69707..7bbbb86 100644 --- a/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj +++ b/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj @@ -2,6 +2,7 @@ Exe + Full netcoreapp2.0 diff --git a/README.md b/README.md index 6950a4e..f281ba6 100644 --- a/README.md +++ b/README.md @@ -10,13 +10,14 @@ MQTTnet is a .NET library for MQTT based communication. It provides a MQTT clien # Features ## General +* Performance optimized (publishing ~18.000 messages per second on local machine) * Async support * TLS 1.2 support for client and server (but not UWP servers) * Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS) * Interfaces included for mocking and testing * Lightweight (only the low level implementation of MQTT, no overhead) * Access to internal trace messages -* Unit tested (50+ tests) +* Unit tested (55+ tests) ## Client * Rx support (via another project) diff --git a/Tests/MQTTnet.Core.Tests/ExtensionTests.cs b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs new file mode 100644 index 0000000..1a3cd29 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/ExtensionTests.cs @@ -0,0 +1,33 @@ +using System; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Core.Exceptions; +using MQTTnet.Core.Internal; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class ExtensionTests + { + [ExpectedException(typeof( MqttCommunicationTimedOutException ) )] + [TestMethod] + public async Task TestTimeoutAfter() + { + await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + } + + [ExpectedException(typeof( MqttCommunicationTimedOutException))] + [TestMethod] + public async Task TestTimeoutAfterWithResult() + { + await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100)); + } + + [TestMethod] + public async Task TestTimeoutAfterCompleteInTime() + { + var result = await Task.Delay( TimeSpan.FromMilliseconds( 100 ) ).ContinueWith( t => 5 ).TimeoutAfter( TimeSpan.FromMilliseconds( 500 ) ); + Assert.AreEqual( 5, result ); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj b/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj index 88bf8f2..2e20398 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj +++ b/Tests/MQTTnet.Core.Tests/MQTTnet.Core.Tests.csproj @@ -86,6 +86,7 @@ + diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 6a18f50..e0ae8fa 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -391,6 +391,12 @@ namespace MQTTnet.Core.Tests { private readonly MemoryStream _stream = new MemoryStream(); + public Stream ReceiveStream => _stream; + + public Stream RawStream => _stream; + + public Stream SendStream => _stream; + public bool IsConnected { get; } = true; public TestChannel() @@ -413,34 +419,16 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task WriteAsync(byte[] buffer) - { - return _stream.WriteAsync(buffer, 0, buffer.Length); - } - - public async Task> ReadAsync(int length, byte[] buffer) - { - await _stream.ReadAsync(buffer, 0, length); - return new ArraySegment(buffer, 0, length); - } - public byte[] ToArray() { return _stream.ToArray(); } - - public int Peek() - { - return (int)_stream.Length - (int)_stream.Position; - } } private static void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; - var channel = new TestChannel(); - serializer.SerializeAsync(packet, channel).Wait(); - var buffer = channel.ToArray(); + var buffer = serializer.Serialize(packet); Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(buffer)); } @@ -448,19 +436,21 @@ namespace MQTTnet.Core.Tests private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value) { var serializer = new MqttPacketSerializer(); + + var buffer1 = serializer.Serialize(packet); - var channel1 = new TestChannel(); - serializer.SerializeAsync(packet, channel1).Wait(); - var buffer1 = channel1.ToArray(); - - var channel2 = new TestChannel(buffer1); - var deserializedPacket = serializer.DeserializeAsync(channel2).Result; - var buffer2 = channel2.ToArray(); + using (var headerStream = new MemoryStream( buffer1 )) + { + var header = MqttPacketReader.ReadHeaderFromSource( headerStream ); - var channel3 = new TestChannel(buffer2); - serializer.SerializeAsync(deserializedPacket, channel3).Wait(); + using (var bodyStream = new MemoryStream( buffer1, (int)headerStream.Position, header.BodyLength )) + { + var deserializedPacket = serializer.Deserialize(header, bodyStream); + var buffer2 = serializer.Serialize( deserializedPacket ); - Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(channel3.ToArray())); + Assert.AreEqual( expectedBase64Value, Convert.ToBase64String( buffer2 ) ); + } + } } } } diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index fe6e448..e3e9028 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Adapter; using MQTTnet.Core.Client; @@ -26,11 +28,15 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout) + public Task SendPacketsAsync(TimeSpan timeout, IEnumerable packets) { ThrowIfPartnerIsNull(); - Partner.SendPacketInternal(packet); + foreach (var packet in packets) + { + Partner.SendPacketInternal(packet); + } + return Task.FromResult(0); } @@ -41,6 +47,11 @@ namespace MQTTnet.Core.Tests return Task.Run(() => _incomingPackets.Take()); } + public IEnumerable ReceivePackets( CancellationToken cancellationToken ) + { + return _incomingPackets.GetConsumingEnumerable(); + } + private void SendPacketInternal(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); diff --git a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs index 1933999..8e0f430 100644 --- a/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs +++ b/Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace MQTTnet.TestApp.NetFramework @@ -17,12 +18,12 @@ namespace MQTTnet.TestApp.NetFramework public static async Task RunAsync() { var server = Task.Run(() => RunServerAsync()); - var client = Task.Run(() => RunClientAsync(1000, 50000, TimeSpan.FromMilliseconds(10))); + var client = Task.Run(() => RunClientAsync(300, TimeSpan.FromMilliseconds(10))); await Task.WhenAll(server, client).ConfigureAwait(false); } - private static async Task RunClientAsync(int messageChunkSize, int totalMessageCount, TimeSpan interval) + private static async Task RunClientAsync( int msgChunkSize, TimeSpan interval ) { try { @@ -77,33 +78,64 @@ namespace MQTTnet.TestApp.NetFramework Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###"); - var applicationMessage = new MqttApplicationMessage( - "A/B/C", - Encoding.UTF8.GetBytes("Hello World"), - MqttQualityOfServiceLevel.AtLeastOnce, - false - ); + var testMessageCount = 1000; + var message = CreateMessage(); + var stopwatch = Stopwatch.StartNew(); + for (var i = 0; i < testMessageCount; i++) + { + await client.PublishAsync(message); + } - var overallCount = 0; - while (overallCount < totalMessageCount) + stopwatch.Stop(); + Console.WriteLine($"Sent 1000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); + + stopwatch.Restart(); + var sentMessagesCount = 0; + while (stopwatch.ElapsedMilliseconds < 1000) { - var stopwatch = Stopwatch.StartNew(); - var count = 0; - for (var i = 0; i < messageChunkSize; i++) + await client.PublishAsync(message); + sentMessagesCount++; + } + + Console.WriteLine($"Sending {sentMessagesCount} messages per second."); + + var last = DateTime.Now; + var msgCount = 0; + + while (true) + { + var msgs = Enumerable.Range( 0, msgChunkSize ) + .Select( i => CreateMessage() ) + .ToList(); + + if (false) + { + //send concurrent (test for raceconditions) + var sendTasks = msgs + .Select( msg => PublishSingleMessage( client, msg, ref msgCount ) ) + .ToList(); + + await Task.WhenAll( sendTasks ); + } + else { - //do not await to send as much messages as possible - await client.PublishAsync(applicationMessage).ConfigureAwait(false); - count++; - overallCount++; + await client.PublishAsync( msgs ); + msgCount += msgs.Count; + //send multiple } - stopwatch.Stop(); + + + var now = DateTime.Now; + if (last < now - TimeSpan.FromSeconds(1)) + { + Console.WriteLine( $"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}" ); + msgCount = 0; + last = now; + } - Console.WriteLine($"Sent {count} messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)count} ms / message)."); await Task.Delay(interval).ConfigureAwait(false); } - - Console.WriteLine($"Completed sending {totalMessageCount} messages."); } catch (Exception exception) { @@ -111,6 +143,25 @@ namespace MQTTnet.TestApp.NetFramework } } + private static MqttApplicationMessage CreateMessage() + { + return new MqttApplicationMessage( + "A/B/C", + Encoding.UTF8.GetBytes( "Hello World" ), + MqttQualityOfServiceLevel.AtMostOnce, + false + ); + } + + private static Task PublishSingleMessage( IMqttClient client, MqttApplicationMessage applicationMessage, ref int count ) + { + Interlocked.Increment( ref count ); + return Task.Run( () => + { + return client.PublishAsync( applicationMessage ); + } ); + } + private static void RunServerAsync() { try @@ -133,12 +184,12 @@ namespace MQTTnet.TestApp.NetFramework }; var mqttServer = new MqttServerFactory().CreateMqttServer(options); - var last = DateTime.UtcNow; + var last = DateTime.Now; var msgs = 0; mqttServer.ApplicationMessageReceived += (sender, args) => { msgs++; - var now = DateTime.UtcNow; + var now = DateTime.Now; if (last < now - TimeSpan.FromSeconds(1)) { Console.WriteLine($"received {msgs}");